
Commit 660a4d5

Refactor interactive mode in main.cpp
1 parent 3839a08 commit 660a4d5


main.cpp

Lines changed: 90 additions & 67 deletions
@@ -28,7 +28,6 @@
 #define ANSI_COLOR_RESET "\x1b[0m"
 #define ANSI_BOLD        "\x1b[1m"
 
-static const int EOS_TOKEN_ID = 2;
 
 // determine number of model parts based on the dimension
 static const std::map<int, int> LLAMA_N_PARTS = {
@@ -56,6 +55,8 @@ void sigint_handler(int signo) {
 #endif
 
 
+void process_interactive_input(llama_context& ctx, const gpt_params& params);
+
 int main(int argc, char ** argv) {
     ggml_time_init();
     const int64_t t_main_start_us = ggml_time_us();
@@ -86,15 +87,18 @@ int main(int argc, char ** argv) {
     // params.prompt = R"(// this function checks if the number n is prime
     //bool is_prime(int n) {)";
 
-    int64_t t_load_us = 0;
-
     // load the model
-    llama_context* ctx_ptr = llama_init_from_params(params);
+    llama_context* ctx_ptr = nullptr;
+    {
+        ctx_ptr = llama_init_from_params(params);
+        if (!ctx_ptr) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+    }
+
     llama_context & ctx = *ctx_ptr;
-    gpt_vocab & vocab = llama_context_get_vocab(ctx);
-
-    // print system information
-    llama_print_context_info(ctx);
+    const gpt_vocab & vocab = llama_context_get_vocab(ctx);
 
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
@@ -110,8 +114,9 @@ int main(int argc, char ** argv) {
     }
 
     // tokenize the reverse prompt
-    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.prompt);
+    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.antiprompt);
 
+    // Setup interactive mode
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -150,43 +155,43 @@ int main(int argc, char ** argv) {
         is_interacting = true;
     }
 
-    bool input_noecho = false;
-
-    int remaining_tokens = params.n_predict;
-
     // set the color for the prompt which will be output initially
     if (params.use_color) {
        printf(ANSI_COLOR_YELLOW);
     }
 
-    if(!llama_ingest_input(ctx, params.prompt))
+    // Prepare the context with input
+    // Send "beginning of string"
+    llama_add_bos(ctx);
+
+    // load the input
+    llama_update_input(ctx, params.prompt);
+
+    llama_print_startup_stats(ctx);
+
+    if (!llama_prepare_context(ctx))
     {
-        fprintf(stderr, "Failed to ingest prompt\n");
+        fprintf(stderr, "%s: failed to prepare context\n", __func__);
         return 1;
-    };
-
-    // display text
-    input_noecho = false;
-    const std::vector<gpt_vocab::id>& embd = llama_context_get_embedding(ctx);
-    if (!input_noecho) {
-        for (auto id : embd) {
-            printf("%s", vocab.id_to_token[id].c_str());
-        }
-        fflush(stdout);
     }
 
-    if (!input_noecho && params.use_color) {
-        printf(ANSI_COLOR_RESET);
-    }
-
-    const std::vector<gpt_vocab::id>& last_n_tokens = llama_context_get_last_n_tokens(ctx);
-
-    while (llama_context_is_finished(ctx) != true) {
-        gpt_vocab::id model_output = 0;
-        bool response = llama_infer(ctx, model_output);
-        if (response) {
-            printf("%s", vocab.id_to_token[model_output].c_str());
-            fflush(stdout);
+    bool input_noecho = false;
+    bool is_end_of_text = false;
+    while (!llama_context_is_finished(ctx)) {
+        std::string model_output{};
+
+        if (llama_has_unconsumed_input(ctx)) {
+            llama_ingest_all_pending_input(ctx, !input_noecho);
+            // reset color to default once all pending input has been consumed
+            if (!input_noecho && params.use_color) {
+                printf(ANSI_COLOR_RESET);
+            }
+        } else {
+            // Run inference only when there is no pending input
+            llama_infer(ctx, model_output, is_end_of_text);
+            // print the single token output
+            printf("%s", model_output.c_str());
+            input_noecho = false;
         }
         // reset color to default if we there is no pending user input
         if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
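
Note: this hunk replaces the single llama_ingest_input(prompt) call with a small setup protocol (add BOS, queue the input, prepare the context) followed by a two-state loop that either consumes pending input or infers one token. A minimal non-interactive driver is sketched below, assuming the llama_* signatures exactly as they appear in this diff; the sketch itself is not part of the commit.

    // Hedged sketch: the call order this refactor establishes.
    // Assumes llama.h from this branch for llama_context and the llama_* calls.
    #include <cstdio>
    #include <string>

    int run_prompt(llama_context & ctx, const std::string & prompt) {
        llama_add_bos(ctx);               // send "beginning of string"
        llama_update_input(ctx, prompt);  // queue the prompt as pending input
        if (!llama_prepare_context(ctx)) {
            fprintf(stderr, "failed to prepare context\n");
            return 1;
        }
        bool is_end_of_text = false;
        while (!llama_context_is_finished(ctx)) {
            if (llama_has_unconsumed_input(ctx)) {
                // the bool argument presumably toggles echo of the ingested
                // tokens, matching the !input_noecho usage in the diff above
                llama_ingest_all_pending_input(ctx, true);
            } else {
                std::string token_text;
                llama_infer(ctx, token_text, is_end_of_text); // one token per call
                printf("%s", token_text.c_str());
            }
        }
        return 0;
    }
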
@@ -195,48 +200,31 @@ int main(int argc, char ** argv) {
 
         // in interactive mode, and not currently processing queued inputs;
         // check if we should prompt the user for more
-        if (params.interactive) {
-            // check for reverse prompt
-            for (auto antiprompt_inp : antipromptv_inp) {
-                if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
+        if (params.interactive && !llama_has_unconsumed_input(ctx)) {
+            // check for reverse prompt
+            if (antiprompt_inp.size() && llama_is_anti_prompt_present(ctx, antiprompt_inp)) {
                     // reverse prompt found
                     is_interacting = true;
                     break;
                 }
             }
             if (is_interacting) {
                 if (params.instruct) {
-                    input_consumed = embd_inp.size();
-                    embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
+                    llama_update_input(ctx, "\n\n### Instruction:\n\n");
 
                     printf("\n> ");
                 }
 
                 // currently being interactive
-                if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
-                std::string buffer;
-                std::string line;
-                bool another_line = true;
-                do {
-                    std::getline(std::cin, line);
-                    if (line.empty() || line.back() != '\\') {
-                        another_line = false;
-                    } else {
-                        line.pop_back(); // Remove the continue character
-                    }
-                    // Do not clear existing context in interactive mode
-                    llama_update_context_with_prompt(ctx, buf, false);
-                }
-
-                remaining_tokens -= line_inp.size();
-
-                input_noecho = true; // do not echo this again
+                process_interactive_input(ctx, params);
+                input_noecho = true; // do not echo this input again
+                is_interacting = false;
             }
             is_interacting = false;
         }
 
         // end of text token
-        if (embd.back() == EOS_TOKEN_ID) {
+        if (is_end_of_text) {
             if (params.interactive) {
                 is_interacting = true;
             } else {
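
Note: llama_is_anti_prompt_present() is not defined in this file. Judging from the std::equal() check it replaces (deleted above), a plausible implementation, given here as an assumption rather than code from this commit, compares the tokenized reverse prompt against the tail of the context's recent tokens:

    // Hypothetical implementation, mirroring the deleted std::equal() check.
    // Assumes llama_context_get_last_n_tokens() as used in the old code above.
    #include <algorithm>
    #include <vector>

    bool llama_is_anti_prompt_present(llama_context & ctx,
                                      const std::vector<gpt_vocab::id> & antiprompt_inp) {
        const auto & last_n_tokens = llama_context_get_last_n_tokens(ctx);
        if (antiprompt_inp.empty() || antiprompt_inp.size() > last_n_tokens.size()) {
            return false;
        }
        // match the reverse prompt against the most recent tokens, back to front
        return std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(),
                          last_n_tokens.rbegin());
    }
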
@@ -246,23 +234,58 @@ int main(int argc, char ** argv) {
             }
 
         // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
-        if (params.interactive && remaining_tokens <= 0) {
-            remaining_tokens = params.n_predict;
+        if (params.interactive && llama_context_is_finished(ctx)) {
+            llama_reset_remaining_tokens(ctx);
             is_interacting = true;
         }
     }
 
-    // report timing from context
+
+#if defined (_WIN32)
+    signal(SIGINT, SIG_DFL);
+#endif
+
+    // report timing
     {
         const int64_t t_main_end_us = ggml_time_us();
         llama_print_end_stats(ctx);
         fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
-    llama_free_context(ctx_ptr);
+
+    llama_free_context(ctx_ptr);
 
     if (params.use_color) {
         printf(ANSI_COLOR_RESET);
     }
-
     return 0;
 }
+
+void process_interactive_input(llama_context& ctx, const gpt_params& params)
+{
+    bool another_line = true;
+    while (another_line) {
+        fflush(stdout);
+        char buf[256] = {0};
+        int n_read;
+        if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
+        if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
+            // presumably an empty line, consume the newline
+            std::ignore = scanf("%*c");
+            n_read = 0;
+        }
+        if (params.use_color) printf(ANSI_COLOR_RESET);
+
+        if (n_read > 0 && buf[n_read-1] == '\\') {
+            another_line = true;
+            buf[n_read-1] = '\n';
+            buf[n_read] = 0;
+        } else {
+            another_line = false;
+            buf[n_read] = '\n';
+            buf[n_read+1] = 0;
+        }
+
+        // Do not clear existing context in interactive mode
+        llama_update_input(ctx, buf);
+    }
+}
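
Note: the scanf()-based reader in process_interactive_input() truncates each line at 255 bytes. The std::getline() loop it replaces had no such cap; below is a sketch of an alternative that keeps the new llama_update_input() API but reads unbounded lines, with the same backslash-continuation behavior. This is an alternative, not what the commit does, and it assumes llama_update_input() accepts a std::string (it is called with params.prompt, a std::string, elsewhere in this diff).

    // Alternative sketch: unbounded interactive line reading.
    #include <cstdio>
    #include <iostream>
    #include <string>

    void process_interactive_input(llama_context & ctx, const gpt_params & params) {
        bool another_line = true;
        while (another_line) {
            if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
            std::string line;
            std::getline(std::cin, line);
            if (params.use_color) printf(ANSI_COLOR_RESET);

            if (!line.empty() && line.back() == '\\') {
                line.back() = '\n';   // continuation: swap the backslash for a newline
            } else {
                line += '\n';
                another_line = false; // no trailing backslash: last line of this input
            }
            // Do not clear existing context in interactive mode
            llama_update_input(ctx, line);
        }
    }
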
