@@ -1,6 +1,7 @@
 #include "llama-context.h"
 
 #include "llama-impl.h"
+#include "llama-batch.h"
 #include "llama-io.h"
 #include "llama-memory.h"
 #include "llama-mmap.h"
@@ -18,7 +19,8 @@
 llama_context::llama_context(
         const llama_model & model,
               llama_context_params params) :
-    model(model) {
+    model(model),
+    batch_allocr(std::make_unique<llama_batch_allocr>()) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
     t_start_us = model.t_start_us;
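Note: the batch allocator is now a long-lived member of `llama_context`, constructed once here and reused by `encode()`/`decode()` below instead of being rebuilt on every call. A rough sketch of how the corresponding member might be declared in `llama-context.h` (that header is not part of this excerpt, so everything beyond the `batch_allocr` name and its `std::unique_ptr<llama_batch_allocr>` type is an assumption):

```cpp
// Hypothetical excerpt of llama-context.h; only batch_allocr and its type are
// implied by the diff above, the rest is illustrative scaffolding.
#include <memory>

struct llama_batch_allocr;  // defined in llama-batch.h

struct llama_context {
    // ...
    const llama_model & model;  // reference captured by the constructor

    // owned allocator, created once in the constructor and reused per call
    std::unique_ptr<llama_batch_allocr> batch_allocr;
    // ...
};
```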
@@ -494,7 +496,7 @@ float * llama_context::get_logits() {
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
-    int32_t j = -1;
+    int64_t j = -1;
 
     try {
         if (logits == nullptr) {
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
         }
         if (j >= n_outputs) {
             // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
+            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
         }
 
         return logits + j*model.vocab.n_tokens();
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() {
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
-    int32_t j = -1;
+    int64_t j = -1;
 
     try {
         if (embd == nullptr) {
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
         }
         if (j >= n_outputs) {
             // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
+            throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
        }
 
         return embd + j*model.hparams.n_embd;
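Both getters widen the local index `j` from `int32_t` to `int64_t`, so the exception messages switch from `%d` to the `PRId64` macro from `<cinttypes>`; passing an `int64_t` through `%d` is undefined behavior wherever `int` is narrower than 64 bits. A minimal standalone illustration of the pattern (not llama.cpp code):

```cpp
#include <cinttypes>  // PRId64
#include <cstdio>

int main() {
    int64_t j         = -1;
    int     n_outputs = 0;
    // "%" PRId64 expands to the correct conversion specifier for int64_t
    std::printf("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)\n", j, n_outputs);
    return 0;
}
```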
@@ -727,18 +729,19 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
+    batch_allocr->init(inp_batch, inp_batch.pos ? -1 : 0);
 
-    const llama_batch & batch = batch_allocr.batch;
-    const int32_t n_tokens = batch.n_tokens;
+    const llama_batch & batch = batch_allocr->get_batch();
+
+    const uint32_t n_tokens = batch.n_tokens;
 
     const auto & hparams = model.hparams;
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     // TODO: move the validation to the llama_batch_allocr
     if (batch.token) {
-        for (int32_t i = 0; i < n_tokens; ++i) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
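As in the constructor hunk, `encode()` now calls `init()` on the shared member instead of creating a stack-local `llama_batch_allocr`, and reads the prepared batch back through `get_batch()`. The real class lives in `llama-batch.h`/`llama-batch.cpp`; the sketch below only re-creates the reuse pattern implied by this diff (synthesizing positions when `inp_batch.pos` is null, starting from the `p0` argument), and its internals are assumptions:

```cpp
#include <cstdint>
#include <vector>

#include "llama.h"  // llama_batch, llama_pos

// assumption-laden sketch, not the actual llama_batch_allocr interface
class batch_allocr_sketch {
public:
    // (re)prepare storage for the incoming batch; p0 is the starting position
    // to synthesize when the caller did not provide batch.pos (p0 == -1 means
    // positions were already supplied)
    void init(const llama_batch & in, llama_pos p0) {
        batch = in;
        if (batch.pos == nullptr && p0 >= 0) {
            pos.resize(batch.n_tokens);
            for (int32_t i = 0; i < batch.n_tokens; ++i) {
                pos[i] = p0 + i;  // consecutive positions starting at p0
            }
            batch.pos = pos.data();
        }
        // the real allocator also fills in seq_id, logits, etc. as needed
    }

    const llama_batch & get_batch() const { return batch; }

private:
    llama_batch            batch{};  // view backed by the storage below
    std::vector<llama_pos> pos;      // synthesized positions, if any
};
```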
@@ -775,7 +778,7 @@ int llama_context::encode(llama_batch & inp_batch) {
         return -2;
     };
 
-    for (int32_t i = 0; i < n_tokens; ++i) {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         output_ids[i] = i;
     }
 
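The loop counters in `encode()` switch to `uint32_t` to match the new unsigned `n_tokens`; mixing a signed counter with an unsigned bound is exactly what `-Wsign-compare` warns about, since the signed operand is converted to unsigned before the comparison. A standalone illustration (not llama.cpp code):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_tokens = 4;

    // with "int32_t i" the condition below would be a signed/unsigned
    // comparison and trigger -Wsign-compare; matching types avoid that
    for (uint32_t i = 0; i < n_tokens; ++i) {
        std::printf("token %u\n", i);
    }
    return 0;
}
```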
@@ -831,7 +834,8 @@ int llama_context::encode(llama_batch & inp_batch) {
 
                     GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
 
-                    for (int32_t i = 0; i < n_tokens; i++) {
+                    // TODO: fix sequence indexing
+                    for (uint32_t i = 0; i < n_tokens; i++) {
                         const llama_seq_id seq_id = ubatch.seq_id[i][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                             continue;
@@ -881,7 +885,7 @@ int llama_context::encode(llama_batch & inp_batch) {
881
885
// TODO: the seuqence indexing here is likely not correct in the general case
882
886
// probably works only for split_simple
883
887
cross.seq_ids_enc .resize (n_tokens);
884
- for (int32_t i = 0 ; i < n_tokens; i++) {
888
+ for (uint32_t i = 0 ; i < n_tokens; i++) {
885
889
cross.seq_ids_enc [i].clear ();
886
890
for (int s = 0 ; s < ubatch.n_seq_id [i]; s++) {
887
891
llama_seq_id seq_id = ubatch.seq_id [i][s];
@@ -912,30 +916,30 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
 
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
+    batch_allocr->init(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
 
-    const llama_batch & batch = batch_allocr.batch;
+    const llama_batch & batch = batch_allocr->get_batch();
 
     const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
 
     const int32_t n_vocab = vocab.n_tokens();
+    const int64_t n_embd  = hparams.n_embd;
 
-    const int64_t n_tokens_all = batch.n_tokens;
-    const int64_t n_embd       = hparams.n_embd;
+    const uint32_t n_tokens_all = batch.n_tokens;
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     // TODO: move the validation to the llama_batch_allocr
     if (batch.token) {
-        for (int64_t i = 0; i < n_tokens_all; ++i) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
 
             if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
                 return -1;
             }
         }
@@ -944,7 +948,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     // this indicates we are doing pooled embedding
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
-    int64_t n_outputs_all = 0;
+    uint32_t n_outputs_all = 0;
 
     // count outputs
     for (uint32_t i = 0; i < n_tokens_all; ++i) {
@@ -954,7 +958,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     if (embd_pooled) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
-            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
                     __func__, n_outputs_all, n_tokens_all);
             return -1;
         }
@@ -1024,7 +1028,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
         return -2;
     };
 
@@ -1063,6 +1067,7 @@ int llama_context::decode(llama_batch & inp_batch) {
                     pos_min[s] = std::numeric_limits<llama_pos>::max();
                 }
 
+                // TODO: fix sequence indexing
                 for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                     const auto & seq_id = ubatch.seq_id[i][0];
 
@@ -1176,14 +1181,14 @@ int llama_context::decode(llama_batch & inp_batch) {
     n_outputs = n_outputs_all;
 
     // set output mappings
-    {
+    if (n_outputs > 0) {
         bool sorted_output = true;
 
         auto & out_ids = mstate->out_ids();
 
-        GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
+        GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
 
-        for (int64_t i = 0; i < n_outputs_all; ++i) {
+        for (int64_t i = 0; i < n_outputs; ++i) {
             int64_t out_id = out_ids[i];
             output_ids[out_id] = i;
             if (out_id != i) {
@@ -1195,20 +1200,22 @@ int llama_context::decode(llama_batch & inp_batch) {
         // note: this is mostly relevant for recurrent models atm
         if (!sorted_output) {
             const uint32_t n_vocab = model.vocab.n_tokens();
-            const uint32_t n_embd  = model.hparams.n_embd;
+            const uint64_t n_embd  = model.hparams.n_embd;
 
             GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
             // TODO: is there something more efficient which also minimizes swaps?
             // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-            for (int32_t i = 0; i < n_outputs - 1; ++i) {
-                int32_t j_min = i;
-                for (int32_t j = i + 1; j < n_outputs; ++j) {
+            for (uint32_t i = 0; i < n_outputs - 1; ++i) {
+                uint32_t j_min = i;
+                for (uint32_t j = i + 1; j < n_outputs; ++j) {
                     if (out_ids[j] < out_ids[j_min]) {
                         j_min = j;
                     }
                 }
-                if (j_min == i) { continue; }
+                if (j_min == i) {
+                    continue;
+                }
                 std::swap(out_ids[i], out_ids[j_min]);
                 if (logits_size > 0) {
                     for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1221,8 +1228,10 @@ int llama_context::decode(llama_batch & inp_batch) {
                     }
                 }
             }
+
             std::fill(output_ids.begin(), output_ids.end(), -1);
-            for (int32_t i = 0; i < n_outputs; ++i) {
+
+            for (uint32_t i = 0; i < n_outputs; ++i) {
                 output_ids[out_ids[i]] = i;
            }
        }
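The selection sort now uses unsigned indices to match `n_outputs`, and the one-line `if (j_min == i) { continue; }` is expanded for readability. The standalone sketch below re-creates the reordering step with toy data: `out_ids` is selection-sorted (chosen to minimize swaps) and every swap is mirrored on the per-output rows of a flat logits buffer, so row `i` still belongs to `out_ids[i]` afterwards:

```cpp
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    const uint32_t n_vocab = 3;

    std::vector<int64_t> out_ids = {2, 0, 1};  // unsorted output order
    std::vector<float>   logits  = {20, 21, 22,  0, 1, 2,  10, 11, 12};  // one row per output

    const uint32_t n_outputs = (uint32_t) out_ids.size();

    // selection sort on out_ids, mirroring each swap on the logits rows
    for (uint32_t i = 0; i < n_outputs - 1; ++i) {
        uint32_t j_min = i;
        for (uint32_t j = i + 1; j < n_outputs; ++j) {
            if (out_ids[j] < out_ids[j_min]) {
                j_min = j;
            }
        }
        if (j_min == i) {
            continue;
        }
        std::swap(out_ids[i], out_ids[j_min]);
        for (uint32_t k = 0; k < n_vocab; k++) {
            std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
        }
    }

    for (uint32_t i = 0; i < n_outputs; ++i) {
        std::printf("out_id %lld -> first logit %.0f\n", (long long) out_ids[i], logits[i*n_vocab]);
    }
    return 0;
}
```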
@@ -1242,7 +1251,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 // output
 //
 
-int32_t llama_context::output_reserve(int32_t n_outputs) {
+uint32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto & hparams = model.hparams;
     const auto & vocab   = model.vocab;
 
@@ -1308,8 +1317,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     // set all ids as invalid (negative)
     std::fill(output_ids.begin(), output_ids.end(), -1);
 
-    this->n_outputs     = 0;
-    this->n_outputs_max = n_outputs_max;
+    this->n_outputs = 0;
 
     return n_outputs_max;
 }
@@ -1800,14 +1808,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
 
     std::vector<int32_t> w_output_pos;
 
-    GGML_ASSERT(n_outputs <= n_outputs_max);
-
     w_output_pos.resize(n_outputs);
 
     // build a more compact representation of the output ids
     for (size_t i = 0; i < n_batch(); ++i) {
         // map an output id to a position in the batch
-        int32_t pos = output_ids[i];
+        int64_t pos = output_ids[i];
         if (pos >= 0) {
             GGML_ASSERT(pos < n_outputs);
             w_output_pos[pos] = i;
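`state_write_data()` drops the `n_outputs_max` bookkeeping (the corresponding assignment is also removed from `output_reserve()` above) and widens `pos` to `int64_t`. The loop builds the compact mapping that gets serialized: `output_ids` maps a batch position to an output index (or -1), and `w_output_pos` inverts that so each output index records the batch position it came from. A standalone sketch with toy values:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // batch position -> output index, -1 means "no output for this token"
    const std::vector<int32_t> output_ids = {-1, 0, -1, 2, 1};
    const int32_t n_outputs = 3;

    std::vector<int32_t> w_output_pos(n_outputs, 0);

    // invert the mapping; in llama_context this loop runs over n_batch()
    for (size_t i = 0; i < output_ids.size(); ++i) {
        const int64_t pos = output_ids[i];  // widened index, as in the hunk above
        if (pos >= 0) {
            // map an output id to a position in the batch
            w_output_pos[pos] = (int32_t) i;
        }
    }

    for (int32_t j = 0; j < n_outputs; ++j) {
        std::printf("output %d came from batch position %d\n", j, w_output_pos[j]);
    }
    return 0;
}
```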
@@ -2082,7 +2088,7 @@ void llama_context::opt_epoch_iter(
 
         embd_seq.clear();
 
-        int64_t n_outputs_all = n_tokens_all;
+        uint32_t n_outputs_all = n_tokens_all;
 
         auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
@@ -2092,7 +2098,7 @@ void llama_context::opt_epoch_iter(
 
         // reserve output buffer
         if (output_reserve(n_outputs_all) < n_outputs_all) {
-            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
             GGML_ABORT("TODO: handle this error");
         };
 