@@ -72,33 +72,36 @@ struct llama_model {
};
struct llama_state
{
-    int64_t t_sample_us = 0;
-    int64_t t_predict_us = 0;
+    // Timers
+    struct timing {
+        int64_t t_load_us = 0;

-    std::vector<float> logits;
+        int64_t t_sample_us = 0;
+        int64_t t_predict_us = 0;
+    } timing;

-    mutable std::mt19937 rng;
+    // Random number generator
+    std::mt19937 rng{};

+    // Tokens
    std::vector<gpt_vocab::id> embd{};
+    std::vector<gpt_vocab::id> embd_inp{};
+    std::vector<gpt_vocab::id> last_n_tokens{};

+    // Logits from inference
+    std::vector<float> logits{};
+
+    // Counters
    int input_consumed = 0;
-    std::vector<gpt_vocab::id> embd_inp;
-    std::vector<gpt_vocab::id> last_n_tokens;
    int remaining_tokens = 0;
    int n_past = 0;
    size_t mem_per_token = 0;
-    bool is_initialized = false;
-    llama_state() {}

-    bool has_more_input() const {
-        return input_consumed < embd_inp.size();
-    }
+    // Flag set after initialization
+    bool is_initialized = false;
};
struct llama_context
{
-    int64_t t_load_us = 0;
-    int64_t t_start_us = 0;
-
    ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)

    llama_model model{};
@@ -111,8 +114,6 @@ struct llama_context
    llama_context() = default;
    // constructor
    llama_context(llama_model&& model, gpt_vocab&& vocab, const gpt_params& params):
-        t_load_us(0),
-        t_start_us(0),
        wtype(ggml_type::GGML_TYPE_F16),
        model(std::move(model)),
        vocab(std::move(vocab)),
@@ -829,7 +830,8 @@ bool llama_context_is_finished(const llama_context& ctx)
    return ctx.state->remaining_tokens <= 0;
}
const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text) {
-    return llama_tokenize(ctx.vocab, text, true);
+    // Make sure that the "beginning of string" token is not prefixed to the text
+    return llama_tokenize(ctx.vocab, text, false);
}
const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_context& ctx) {
    return ctx.state->last_n_tokens;
@@ -847,7 +849,8 @@ llama_context* llama_init_from_params(const gpt_params& params) {
        return nullptr;
    }
    llama_context* ctx = new llama_context(std::move(model), std::move(vocab), params);
-   ctx->t_load_us = t_end - t_start;
+   ctx->state->timing.t_load_us = t_end - t_start;
+   ctx->state->rng = std::mt19937(params.seed);
    return ctx;
}
void llama_free_context(llama_context* ctx) {
@@ -874,7 +877,7 @@ const char * llama_print_system_info(void) {
    return s.c_str();
}

-void llama_print_context_info(const llama_context& ctx)
+void llama_print_startup_stats(const llama_context& ctx)
{
    const gpt_params& params = ctx.params;
    const std::vector<gpt_vocab::id>& embd_inp = ctx.state->embd_inp;
@@ -897,9 +900,9 @@ void llama_print_end_stats(const llama_context& ctx)
    const llama_state& state = *ctx.state;
    fprintf(stderr, "\n\n");
    fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, state.mem_per_token);
-   fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, ctx.t_load_us/1000.0f);
-   fprintf(stderr, "%s:   sample time = %8.2f ms\n", __func__, state.t_sample_us/1000.0f);
-   fprintf(stderr, "%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, state.t_predict_us/1000.0f, state.t_predict_us/1000.0f/state.n_past);
+   fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, state.timing.t_load_us/1000.0f);
+   fprintf(stderr, "%s:   sample time = %8.2f ms\n", __func__, state.timing.t_sample_us/1000.0f);
+   fprintf(stderr, "%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, state.timing.t_predict_us/1000.0f, state.timing.t_predict_us/1000.0f/state.n_past);
}
// evaluate the transformer
//
@@ -1137,25 +1140,26 @@ bool llama_eval(
    return true;
}

-bool llama_update_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing) {
+void llama_update_input(llama_context& ctx, const std::string& text)
+{
    llama_state& state = *ctx.state;
    llama_model& model = ctx.model;
    const gpt_params& params = ctx.params;

-    if (clear_existing) {
-        state.embd.clear();
-        state.input_consumed = 0;
-        state.embd_inp.clear();
-        state.last_n_tokens.clear();
-        state.remaining_tokens = 0;
-        state.n_past = 0;
-    }
-
    std::vector<gpt_vocab::id> line_inp = llama_tokenize_text(ctx, text);
+
    state.embd_inp.insert(state.embd_inp.end(), line_inp.begin(), line_inp.end());
+    state.remaining_tokens -= line_inp.size();
+}
+
+bool llama_prepare_context(llama_context& ctx)
+{
+    llama_state& state = *ctx.state;
+    llama_model& model = ctx.model;
+    gpt_params& params = ctx.params;

    int n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) state.embd_inp.size());
-   state.remaining_tokens = n_predict;
+   params.n_predict = n_predict;

    // determine the required inference memory per token:
    state.mem_per_token = 0;
@@ -1168,8 +1172,9 @@ bool llama_update_context_with_prompt(llama_context& ctx, const std::string& tex
    int last_n_size = params.repeat_last_n;
    state.last_n_tokens = std::vector<gpt_vocab::id>(last_n_size);
    std::fill(state.last_n_tokens.begin(), state.last_n_tokens.end(), 0);
-
    state.is_initialized = true;
+   state.remaining_tokens = params.n_predict;
+   state.input_consumed = 0;
    return true;
}

@@ -1180,36 +1185,54 @@ void llama_ingest_input_batch(llama_context& ctx)
    llama_state& state = *ctx.state;
    const gpt_params& params = ctx.params;

-   // Copy at most n_batch elements from embd_inp to embd
-   size_t num_copied = std::min((size_t) params.n_batch, state.embd_inp.size() - state.input_consumed);
-   std::copy(state.embd_inp.begin() + state.input_consumed,
-             state.embd_inp.begin() + state.input_consumed + num_copied,
-             std::back_inserter(state.embd));
-   state.input_consumed += num_copied;
-
-   // Copy the last `repeat_last_n` elements copied into embd to last_n_tokens
-   size_t num_copied_last_n = std::min(num_copied, (size_t) params.repeat_last_n);
-   state.last_n_tokens.erase(state.last_n_tokens.begin(), state.last_n_tokens.begin()+num_copied_last_n);
-   state.last_n_tokens.insert(state.last_n_tokens.end(), state.embd.end() - num_copied_last_n, state.embd.end());
+   // some user input remains from prompt or interaction, forward it to processing
+   while (state.embd_inp.size() > state.input_consumed) {
+       state.embd.push_back(state.embd_inp[state.input_consumed]);
+       state.last_n_tokens.erase(state.last_n_tokens.begin());
+       state.last_n_tokens.push_back(state.embd_inp[state.input_consumed]);
+       ++state.input_consumed;
+       if (state.embd.size() > params.n_batch) {
+           break;
+       }
+   }
+   // // Copy at most n_batch elements from embd_inp to embd
+   // size_t num_copied = std::min((size_t) params.n_batch+1, state.embd_inp.size() - state.input_consumed);
+   // std::copy(state.embd_inp.begin() + state.input_consumed,
+   //           state.embd_inp.begin() + state.input_consumed + num_copied,
+   //           std::back_inserter(state.embd));
+   // state.input_consumed += num_copied;
+
+   // // Copy the last `repeat_last_n` elements copied into embd to last_n_tokens
+   // size_t num_copied_last_n = std::min(num_copied, (size_t) params.repeat_last_n);
+   // state.last_n_tokens.erase(state.last_n_tokens.begin(), state.last_n_tokens.begin()+num_copied_last_n);
+   // state.last_n_tokens.insert(state.last_n_tokens.end(), state.embd.end() - num_copied_last_n, state.embd.end());
}

-/// @brief Run the prediction step on ctx.embd and store result in ctx.state.logits
-/// @param ctx
-/// @return
-bool llama_predict(llama_context& ctx){
-   const int64_t t_start_us = ggml_time_us();
+bool llama_eval_model(llama_context& ctx)
+{
    llama_state& state = *ctx.state;
    llama_model& model = ctx.model;
    const gpt_params& params = ctx.params;

-   if (!llama_eval(model, params.n_threads, state.n_past, state.embd, state.logits, state.mem_per_token)) {
-       fprintf(stderr, "Failed to predict\n");
-       return false;
-   }
+   if (state.embd.size() > 0) {
+       const int64_t t_start_us = ggml_time_us();

-   state.t_predict_us += ggml_time_us() - t_start_us;
+       if (!llama_eval(model, params.n_threads, state.n_past, state.embd, state.logits, state.mem_per_token)) {
+           fprintf(stderr, "Failed to predict\n");
+           return false;
+       }
+       state.timing.t_predict_us += ggml_time_us() - t_start_us;
+   }
+   state.n_past += state.embd.size();
+   state.embd.clear();
    return true;
}
+bool llama_has_unconsumed_input(llama_context& ctx)
+{
+   llama_state& state = *ctx.state;
+   return state.input_consumed < state.embd_inp.size();
+}
+
/// @brief Sample a token from the logits
/// @param ctx
/// @return token id
@@ -1237,34 +1260,34 @@ gpt_vocab::id llama_sample_token(llama_context& ctx)
        state.last_n_tokens.erase(state.last_n_tokens.begin());
        state.last_n_tokens.push_back(id);

-       state.t_sample_us += ggml_time_us() - t_start_sample_us;
+       state.timing.t_sample_us += ggml_time_us() - t_start_sample_us;
    }
    return id;
}
/// @brief Ingest all input (in multiple batches) into model and run call predict()
/// @param ctx
-bool llama_ingest_input(llama_context& ctx, const std::string& text, bool clear_existing)
+bool llama_ingest_all_pending_input(llama_context& ctx, bool print_tokens)
{
    llama_state& state = *ctx.state;
+   const std::vector<gpt_vocab::id>& embd = state.embd;
+   gpt_vocab& vocab = ctx.vocab;

-   // Initialize context, tokenize text and clear existing state if necessary
-   if (!state.is_initialized && !llama_update_context_with_prompt(ctx, text, clear_existing))
+   if (!state.is_initialized)
    {
+       fprintf(stderr, "Context must be initialized before ingesting input");
        return false;
    }

    // ingest the tokens into the model one batch at a time
-   while (state.has_more_input())
+   while (llama_has_unconsumed_input(ctx))
    {
        llama_ingest_input_batch(ctx);
-       if (state.embd.size() >= 0) {
-           if (!llama_predict(ctx))
-           {
-               return false;
-           };
+       if (print_tokens) {
+           std::string s = llama_tokens_to_string(vocab, embd);
+           printf("%s", s.c_str());
+           fflush(stdout);
        }
-       state.n_past += state.embd.size();
-       state.embd.clear();
+       llama_eval_model(ctx);
    }
    return true;
}
@@ -1283,25 +1306,45 @@ bool llama_infer(llama_context& ctx, gpt_vocab::id& id) {
        return false;
    }

-   // Do prediction if we have enough tokens
-   if (state.embd.size() > 0) {
-       if (!llama_predict(ctx))
-       {
-           return false;
-       }
-   }
-   // sample a token
+   // Already predicted, so we just need to sample
+   // sample a token
    id = llama_sample_token(ctx);
+
    // add it to the context
    state.embd.push_back(id);

-   state.n_past += 1;
    // decrement remaining sampling budget
    --state.remaining_tokens;

-   // end of text token
-   if (state.embd.back() == 2) {
-       state.remaining_tokens = 0;
+   return true;
+}
+bool llama_infer(llama_context& ctx, std::string& output, bool& is_end_of_text) {
+   // Call overloaded llama_infer and convert to string before returning
+   gpt_vocab::id id_int;
+   is_end_of_text = false;
+   if (!llama_infer(ctx, id_int)){
+       return false;
    }
+
+   // Pass through the "end of text" token to the user
+   is_end_of_text = (id_int == 2);
+
+   // Make sure to pass in the newly generated token to the model as well
+   llama_eval_model(ctx);
+   output = ctx.vocab.id_to_token.at(id_int);
    return true;
}
+bool llama_add_bos(llama_context& ctx){
+   // Add the "bos" token into the model input
+   llama_state& state = *ctx.state;
+   llama_model& model = ctx.model;
+   const gpt_params& params = ctx.params;
+
+   const gpt_vocab::id bos_token = 1;
+   state.embd_inp.push_back(bos_token);
+}
+bool llama_is_anti_prompt_present(llama_context& ctx, const std::vector<gpt_vocab::id>& antiprompt_inp)
+{
+   llama_state& state = *ctx.state;
+   return std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), state.last_n_tokens.rbegin());
+}
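
For orientation, below is a minimal, hypothetical driver showing how the refactored API in this diff could be wired together end to end. It is only a sketch: the header name "llama.h", the gpt_params_parse() helper, and the params.prompt field are assumptions and are not part of this commit; only the llama_* calls themselves appear in the diff above.

// Hypothetical usage sketch only -- not part of the commit above.
// Assumes "llama.h" (name assumed) declares the API shown in this diff,
// and that gpt_params provides gpt_params_parse() and a .prompt field.
#include <cstdio>
#include <string>

#include "llama.h"

int main(int argc, char** argv) {
    gpt_params params;
    if (!gpt_params_parse(argc, argv, params)) {   // assumed helper
        return 1;
    }

    llama_context* ctx = llama_init_from_params(params);
    if (ctx == nullptr) {
        return 1;
    }

    // Queue the BOS token and the prompt, then size buffers and counters once.
    llama_add_bos(*ctx);
    llama_update_input(*ctx, params.prompt);       // .prompt assumed
    llama_prepare_context(*ctx);
    llama_print_startup_stats(*ctx);

    // Push all pending input through the model, echoing the prompt tokens.
    if (!llama_ingest_all_pending_input(*ctx, /*print_tokens=*/true)) {
        llama_free_context(ctx);
        return 1;
    }

    // Sample one token at a time until the budget runs out or end-of-text appears.
    while (!llama_context_is_finished(*ctx)) {
        std::string token;
        bool is_end_of_text = false;
        if (!llama_infer(*ctx, token, is_end_of_text) || is_end_of_text) {
            break;
        }
        printf("%s", token.c_str());
        fflush(stdout);
    }

    llama_print_end_stats(*ctx);
    llama_free_context(ctx);
    return 0;
}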