
Commit 292bf64

Refactor llama.cpp and llama.h
1 parent 36639c3 commit 292bf64

2 files changed, +146 -83 lines

llama.cpp

Lines changed: 122 additions & 79 deletions
@@ -72,33 +72,36 @@ struct llama_model {
 };
 struct llama_state
 {
-    int64_t t_sample_us = 0;
-    int64_t t_predict_us = 0;
+    // Timers
+    struct timing {
+        int64_t t_load_us = 0;

-    std::vector<float> logits;
+        int64_t t_sample_us = 0;
+        int64_t t_predict_us = 0;
+    } timing;

-    mutable std::mt19937 rng;
+    // Random number generator
+    std::mt19937 rng{};

+    // Tokens
     std::vector<gpt_vocab::id> embd{};
+    std::vector<gpt_vocab::id> embd_inp{};
+    std::vector<gpt_vocab::id> last_n_tokens{};

+    // Logits from inference
+    std::vector<float> logits{};
+
+    // Counters
     int input_consumed = 0;
-    std::vector<gpt_vocab::id> embd_inp;
-    std::vector<gpt_vocab::id> last_n_tokens;
     int remaining_tokens = 0;
     int n_past = 0;
     size_t mem_per_token = 0;
-    bool is_initialized = false;
-    llama_state() {}

-    bool has_more_input() const {
-        return input_consumed < embd_inp.size();
-    }
+    // Flag set after initialization
+    bool is_initialized = false;
 };
 struct llama_context
 {
-    int64_t t_load_us = 0;
-    int64_t t_start_us = 0;
-
     ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 or FP16)

     llama_model model{};
@@ -111,8 +114,6 @@ struct llama_context
     llama_context() = default;
     // constructor
     llama_context(llama_model&& model, gpt_vocab&& vocab, const gpt_params& params):
-        t_load_us(0),
-        t_start_us(0),
         wtype(ggml_type::GGML_TYPE_F16),
         model(std::move(model)),
         vocab(std::move(vocab)),
@@ -829,7 +830,8 @@ bool llama_context_is_finished(const llama_context& ctx)
     return ctx.state->remaining_tokens <= 0;
 }
 const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text) {
-    return llama_tokenize(ctx.vocab, text, true);
+    // Make sure that the "beginning of string" token is not prefixed to the text
+    return llama_tokenize(ctx.vocab, text, false);
 }
 const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_context& ctx) {
     return ctx.state->last_n_tokens;
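With this change llama_tokenize_text() never prepends the "beginning of string" token; the new llama_add_bos() introduced later in this diff queues it explicitly. A minimal sketch of how a caller might combine the two, using only functions from this commit; prime_prompt is a hypothetical helper name:

    #include <string>
    #include "llama.h"

    // Hypothetical helper: queue a prompt so that it starts with BOS.
    // llama_tokenize_text() no longer adds BOS, so it is queued separately.
    void prime_prompt(llama_context& ctx, const std::string& prompt) {
        llama_add_bos(ctx);               // pushes the BOS token (id 1) into the pending input
        llama_update_input(ctx, prompt);  // tokenizes without BOS and appends the tokens
    }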
@@ -847,7 +849,8 @@ llama_context* llama_init_from_params(const gpt_params& params) {
         return nullptr;
     }
     llama_context* ctx = new llama_context(std::move(model), std::move(vocab), params);
-    ctx->t_load_us = t_end - t_start;
+    ctx->state->timing.t_load_us = t_end - t_start;
+    ctx->state->rng = std::mt19937(params.seed);
     return ctx;
 }
 void llama_free_context(llama_context* ctx) {
@@ -874,7 +877,7 @@ const char * llama_print_system_info(void) {
     return s.c_str();
 }

-void llama_print_context_info(const llama_context& ctx)
+void llama_print_startup_stats(const llama_context& ctx)
 {
     const gpt_params& params = ctx.params;
     const std::vector<gpt_vocab::id>& embd_inp = ctx.state->embd_inp;
@@ -897,9 +900,9 @@ void llama_print_end_stats(const llama_context& ctx)
     const llama_state& state = *ctx.state;
     fprintf(stderr, "\n\n");
     fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, state.mem_per_token);
-    fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, ctx.t_load_us/1000.0f);
-    fprintf(stderr, "%s:   sample time = %8.2f ms\n", __func__, state.t_sample_us/1000.0f);
-    fprintf(stderr, "%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, state.t_predict_us/1000.0f, state.t_predict_us/1000.0f/state.n_past);
+    fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, state.timing.t_load_us/1000.0f);
+    fprintf(stderr, "%s:   sample time = %8.2f ms\n", __func__, state.timing.t_sample_us/1000.0f);
+    fprintf(stderr, "%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, state.timing.t_predict_us/1000.0f, state.timing.t_predict_us/1000.0f/state.n_past);
 }
 // evaluate the transformer
 //
@@ -1137,25 +1140,26 @@ bool llama_eval(
     return true;
 }

-bool llama_update_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing) {
+void llama_update_input(llama_context& ctx, const std::string& text)
+{
     llama_state& state = *ctx.state;
     llama_model& model = ctx.model;
     const gpt_params& params = ctx.params;

-    if (clear_existing) {
-        state.embd.clear();
-        state.input_consumed = 0;
-        state.embd_inp.clear();
-        state.last_n_tokens.clear();
-        state.remaining_tokens = 0;
-        state.n_past = 0;
-    }
-
     std::vector<gpt_vocab::id> line_inp = llama_tokenize_text(ctx, text);
+
     state.embd_inp.insert(state.embd_inp.end(), line_inp.begin(), line_inp.end());
+    state.remaining_tokens -= line_inp.size();
+}
+
+bool llama_prepare_context(llama_context& ctx)
+{
+    llama_state& state = *ctx.state;
+    llama_model& model = ctx.model;
+    gpt_params& params = ctx.params;

     int n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) state.embd_inp.size());
-    state.remaining_tokens = n_predict;
+    params.n_predict = n_predict;

     // determine the required inference memory per token:
     state.mem_per_token = 0;
@@ -1168,8 +1172,9 @@ bool llama_update_context_with_prompt(llama_context& ctx, const std::string& tex
     int last_n_size = params.repeat_last_n;
     state.last_n_tokens = std::vector<gpt_vocab::id>(last_n_size);
     std::fill(state.last_n_tokens.begin(), state.last_n_tokens.end(), 0);
-
     state.is_initialized = true;
+    state.remaining_tokens = params.n_predict;
+    state.input_consumed = 0;
     return true;
 }

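Prompt handling is now split in two: llama_update_input() only tokenizes and queues text, while llama_prepare_context() clamps n_predict against the remaining context window, measures mem_per_token, and resets the counters. A minimal sketch of the setup order implied by this hunk; setup_generation is a hypothetical helper, and any names not appearing in this diff are assumptions:

    #include <cstdio>
    #include <string>
    #include "llama.h"  // assumed to also declare gpt_params and gpt_vocab

    // Hypothetical helper: queue the initial input, then prepare the context.
    bool setup_generation(llama_context& ctx, const std::string& prompt) {
        llama_add_bos(ctx);               // queue BOS
        llama_update_input(ctx, prompt);  // queue the tokenized prompt
        // Runs after all initial input is queued: it sizes n_predict against the
        // context window, measures mem_per_token, and marks the state initialized.
        if (!llama_prepare_context(ctx)) {
            fprintf(stderr, "llama_prepare_context() failed\n");
            return false;
        }
        return true;
    }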
@@ -1180,36 +1185,54 @@ void llama_ingest_input_batch(llama_context& ctx)
     llama_state& state = *ctx.state;
     const gpt_params& params = ctx.params;

-    // Copy at most n_batch elements from embd_inp to embd
-    size_t num_copied = std::min((size_t) params.n_batch, state.embd_inp.size() - state.input_consumed);
-    std::copy(state.embd_inp.begin() + state.input_consumed,
-              state.embd_inp.begin() + state.input_consumed + num_copied,
-              std::back_inserter(state.embd));
-    state.input_consumed += num_copied;
-
-    // Copy the last `repeat_last_n` elements copied into embd to last_n_tokens
-    size_t num_copied_last_n = std::min(num_copied, (size_t) params.repeat_last_n);
-    state.last_n_tokens.erase(state.last_n_tokens.begin(), state.last_n_tokens.begin()+num_copied_last_n);
-    state.last_n_tokens.insert(state.last_n_tokens.end(), state.embd.end() - num_copied_last_n, state.embd.end());
+    // some user input remains from prompt or interaction, forward it to processing
+    while (state.embd_inp.size() > state.input_consumed) {
+        state.embd.push_back(state.embd_inp[state.input_consumed]);
+        state.last_n_tokens.erase(state.last_n_tokens.begin());
+        state.last_n_tokens.push_back(state.embd_inp[state.input_consumed]);
+        ++state.input_consumed;
+        if (state.embd.size() > params.n_batch) {
+            break;
+        }
+    }
+    // // Copy at most n_batch elements from embd_inp to embd
+    // size_t num_copied = std::min((size_t) params.n_batch+1, state.embd_inp.size() - state.input_consumed);
+    // std::copy(state.embd_inp.begin() + state.input_consumed,
+    //           state.embd_inp.begin() + state.input_consumed + num_copied,
+    //           std::back_inserter(state.embd));
+    // state.input_consumed += num_copied;
+
+    // // Copy the last `repeat_last_n` elements copied into embd to last_n_tokens
+    // size_t num_copied_last_n = std::min(num_copied, (size_t) params.repeat_last_n);
+    // state.last_n_tokens.erase(state.last_n_tokens.begin(), state.last_n_tokens.begin()+num_copied_last_n);
+    // state.last_n_tokens.insert(state.last_n_tokens.end(), state.embd.end() - num_copied_last_n, state.embd.end());
 }

-/// @brief Run the prediction step on ctx.embd and store result in ctx.state.logits
-/// @param ctx
-/// @return
-bool llama_predict(llama_context& ctx){
-    const int64_t t_start_us = ggml_time_us();
+bool llama_eval_model(llama_context& ctx)
+{
     llama_state& state = *ctx.state;
     llama_model& model = ctx.model;
     const gpt_params& params = ctx.params;

-    if (!llama_eval(model, params.n_threads, state.n_past, state.embd, state.logits, state.mem_per_token)) {
-        fprintf(stderr, "Failed to predict\n");
-        return false;
-    }
+    if (state.embd.size() > 0) {
+        const int64_t t_start_us = ggml_time_us();

-    state.t_predict_us += ggml_time_us() - t_start_us;
+        if (!llama_eval(model, params.n_threads, state.n_past, state.embd, state.logits, state.mem_per_token)) {
+            fprintf(stderr, "Failed to predict\n");
+            return false;
+        }
+        state.timing.t_predict_us += ggml_time_us() - t_start_us;
+    }
+    state.n_past += state.embd.size();
+    state.embd.clear();
     return true;
 }
+bool llama_has_unconsumed_input(llama_context& ctx)
+{
+    llama_state& state = *ctx.state;
+    return state.input_consumed < state.embd_inp.size();
+}
+
 /// @brief Sample a token from the logits
 /// @param ctx
 /// @return token id
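llama_ingest_all_pending_input() further down drives these two primitives in a loop; the same loop can be written by hand when a caller wants to interleave its own work between batches. A minimal sketch using only functions introduced in this diff; ingest_by_hand is a hypothetical helper:

    #include "llama.h"

    // Feed all queued input to the model one batch at a time (roughly what
    // llama_ingest_all_pending_input() does, minus the optional token printing).
    bool ingest_by_hand(llama_context& ctx) {
        while (llama_has_unconsumed_input(ctx)) {
            llama_ingest_input_batch(ctx);  // move up to ~n_batch tokens from embd_inp to embd
            if (!llama_eval_model(ctx)) {   // run the transformer on the batch, advance n_past
                return false;
            }
        }
        return true;
    }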
@@ -1237,34 +1260,34 @@ gpt_vocab::id llama_sample_token(llama_context& ctx)
         state.last_n_tokens.erase(state.last_n_tokens.begin());
         state.last_n_tokens.push_back(id);

-        state.t_sample_us += ggml_time_us() - t_start_sample_us;
+        state.timing.t_sample_us += ggml_time_us() - t_start_sample_us;
     }
     return id;
 }
 /// @brief Ingest all input (in multiple batches) into model and run call predict()
 /// @param ctx
-bool llama_ingest_input(llama_context& ctx, const std::string& text, bool clear_existing)
+bool llama_ingest_all_pending_input(llama_context& ctx, bool print_tokens)
 {
     llama_state& state = *ctx.state;
+    const std::vector<gpt_vocab::id>& embd = state.embd;
+    gpt_vocab& vocab = ctx.vocab;

-    // Initialize context, tokenize text and clear existing state if necessary
-    if(!state.is_initialized && !llama_update_context_with_prompt(ctx, text, clear_existing))
+    if(!state.is_initialized)
     {
+        fprintf(stderr, "Context must be initialized before ingesting input");
         return false;
     }

     // ingest the tokens into the model one batch at a time
-    while (state.has_more_input())
+    while (llama_has_unconsumed_input(ctx))
     {
         llama_ingest_input_batch(ctx);
-        if (state.embd.size() >= 0) {
-            if(!llama_predict(ctx))
-            {
-                return false;
-            };
+        if (print_tokens) {
+            std::string s = llama_tokens_to_string(vocab, embd);
+            printf("%s", s.c_str());
+            fflush(stdout);
         }
-        state.n_past += state.embd.size();
-        state.embd.clear();
+        llama_eval_model(ctx);
     }
     return true;
 }
@@ -1283,25 +1306,45 @@ bool llama_infer(llama_context& ctx, gpt_vocab::id& id) {
         return false;
     }

-    // Do prediction if we have enough tokens
-    if (state.embd.size() > 0) {
-        if(!llama_predict(ctx))
-        {
-            return false;
-        }
-    }
-    // sample a token
+    // Already predicted, so we just need to sample
+    // sample a token
     id = llama_sample_token(ctx);
+
     // add it to the context
     state.embd.push_back(id);

-    state.n_past += 1;
     // decrement remaining sampling budget
     --state.remaining_tokens;

-    // end of text token
-    if (state.embd.back() == 2) {
-        state.remaining_tokens = 0;
+    return true;
+}
+bool llama_infer(llama_context& ctx, std::string& output, bool& is_end_of_text) {
+    // Call overloaded llama_infer and convert to string before returning
+    gpt_vocab::id id_int;
+    is_end_of_text = false;
+    if(!llama_infer(ctx, id_int)){
+        return false;
     }
+
+    // Pass through the "end of text" token to the user
+    is_end_of_text = (id_int == 2);
+
+    // Make sure to pass in the newly generated token to the model as well
+    llama_eval_model(ctx);
+    output = ctx.vocab.id_to_token.at(id_int);
     return true;
 }
+bool llama_add_bos(llama_context& ctx){
+    // Add the "bos" token into the model input
+    llama_state& state = *ctx.state;
+    llama_model& model = ctx.model;
+    const gpt_params& params = ctx.params;
+
+    const gpt_vocab::id bos_token = 1;
+    state.embd_inp.push_back(bos_token);
+}
+bool llama_is_anti_prompt_present(llama_context& ctx, const std::vector<gpt_vocab::id>& antiprompt_inp)
+{
+    llama_state& state = *ctx.state;
+    return std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), state.last_n_tokens.rbegin());
+}
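The new string overload of llama_infer() samples a token, immediately feeds it back through llama_eval_model(), and reports EOS via is_end_of_text, so a caller's generation loop reduces to something like the sketch below; generate is a hypothetical helper, and the loop mirrors only what this diff exposes rather than any particular example program:

    #include <cstdio>
    #include <string>
    #include "llama.h"

    // Hypothetical driver loop around the new llama_infer() overload.
    bool generate(llama_context& ctx) {
        while (!llama_context_is_finished(ctx)) {
            std::string token_text;
            bool is_end_of_text = false;
            if (!llama_infer(ctx, token_text, is_end_of_text)) {
                return false;  // inference failed
            }
            if (is_end_of_text) {
                break;         // model emitted the end-of-text token (id 2)
            }
            printf("%s", token_text.c_str());
            fflush(stdout);
        }
        return true;
    }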

llama.h

Lines changed: 24 additions & 4 deletions
@@ -41,13 +41,33 @@ struct llama_context;

 // Startup
 llama_context* llama_init_from_params(const gpt_params& params);
+bool llama_prepare_context(llama_context& ctx);

 // Input processing and inference
-bool llama_ingest_input(llama_context& ctx, const std::string& text, bool clear_existing = true);
-bool llama_context_is_finished(const llama_context& ctx);
-bool llama_update_context_with_prompt(llama_context& ctx, const std::string& text, bool clear_existing = true);
+// Tokenize text (never adds BOS)
 const std::vector<gpt_vocab::id> llama_tokenize_text(const llama_context& ctx, const std::string& text);
+// Queues up a BOS token to the model input
+bool llama_add_bos(llama_context& ctx);
+// Queues up input text to the model input
+void llama_update_input(llama_context& ctx, const std::string& text);
+// Ingests input previously added using llama_update_input()
+void llama_ingest_input_batch(llama_context& ctx);
+// Ingests all input previously added using llama_update_input() in multiple batches
+// Batch size is determined by gpt_params::n_batch
+bool llama_ingest_all_pending_input(llama_context& ctx, bool print_tokens = false);
+// Checks if the model has unconsumed input to be ingested using llama_ingest_input_batch()
+bool llama_has_unconsumed_input(llama_context& ctx);
+// Checks if the model has an anti-prompt present in its most recent output
+bool llama_is_anti_prompt_present(llama_context& ctx, const std::vector<gpt_vocab::id>& antiprompt_inp);
+
+// Evaluate the model on a batch of input. Must call llama_ingest_input_batch() first.
+bool llama_eval_model(llama_context& ctx);
+// Checks if the model has finished generating output (i.e. has generated an EOS token or remaining_tokens == 0)
+bool llama_context_is_finished(const llama_context& ctx);
+
+// Overloaded functions to run inference and return either the model output or the decoded text
 bool llama_infer(llama_context& ctx, gpt_vocab::id& model_output);
+bool llama_infer(llama_context& ctx, std::string& output, bool& is_end_of_text);

 // Teardown
 void llama_free_context(llama_context* ctx);
@@ -61,5 +81,5 @@ const std::vector<gpt_vocab::id>& llama_context_get_last_n_tokens(const llama_co
 bool llama_model_quantize(const std::string & fname_inp, const std::string & fname_out, int itype);

 // Stats
-void llama_print_context_info(const llama_context& ctx);
+void llama_print_startup_stats(const llama_context& ctx);
 void llama_print_end_stats(const llama_context& ctx);
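Taken together, the reworked header implies a lifecycle roughly like the following. This is a sketch only: gpt_params field names other than those appearing in this diff (seed, n_threads, n_predict, n_batch, repeat_last_n) are assumptions, as is how the model path is supplied.

    #include <cstdio>
    #include <string>
    #include "llama.h"  // assumed to bring in gpt_params / gpt_vocab

    int main() {
        gpt_params params;  // assumed to default-construct with usable settings
        // params.model = "models/7B/ggml-model-q4_0.bin";  // assumed field name for the model path

        llama_context* ctx = llama_init_from_params(params);
        if (ctx == nullptr) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        llama_add_bos(*ctx);                            // queue BOS
        llama_update_input(*ctx, "Hello, llama.cpp!");  // queue the prompt
        if (!llama_prepare_context(*ctx)) {             // finalize n_predict, counters, buffers
            return 1;
        }
        llama_print_startup_stats(*ctx);

        llama_ingest_all_pending_input(*ctx, /*print_tokens=*/true);

        std::string token_text;
        bool is_end_of_text = false;
        while (!llama_context_is_finished(*ctx)) {
            if (!llama_infer(*ctx, token_text, is_end_of_text) || is_end_of_text) {
                break;
            }
            printf("%s", token_text.c_str());
            fflush(stdout);
        }

        llama_print_end_stats(*ctx);
        llama_free_context(ctx);
        return 0;
    }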
