 #include "gguf-llama.h"
 #include "build-info.h"

-#include <cassert>
-#include <cinttypes>
 #include <cmath>
 #include <cstdio>
-#include <cstring>
-#include <ctime>
-#include <fstream>
-#include <iostream>
 #include <string>
 #include <vector>

-#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
-#include <signal.h>
-#include <unistd.h>
-#elif defined (_WIN32)
-#define WIN32_LEAN_AND_MEAN
-#define NOMINMAX
-#include <windows.h>
-#include <signal.h>
-#endif
-
-
-
-int main(int argc, char ** argv)
-{
+int main(int argc, char ** argv) {
     gpt_params params;

-    //---------------------------------
-    // Print help :
-    //---------------------------------
-
-    if ( argc == 1 || argv[1][0] == '-' )
-    {
-        printf( "usage: %s MODEL_PATH [PROMPT]\n" , argv[0] );
+    if (argc == 1 || argv[1][0] == '-') {
+        printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]);
         return 1 ;
     }

-    //---------------------------------
-    // Load parameters :
-    //---------------------------------
-
-    if ( argc >= 2 )
-    {
+    if (argc >= 2) {
         params.model = argv[1];
     }

-    if ( argc >= 3 )
-    {
+    if (argc >= 3) {
         params.prompt = argv[2];
     }

-    if ( params.prompt.empty() )
-    {
+    if (params.prompt.empty()) {
         params.prompt = "Hello my name is";
     }

-    //---------------------------------
-    // Init LLM :
-    //---------------------------------
+    // init LLM

     llama_backend_init(params.numa);

     llama_context_params ctx_params = llama_context_default_params();

     llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
-
-    if ( model == NULL )
-    {
-        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
+
+    if (model == NULL) {
+        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
         return 1;
     }

     llama_context * ctx = llama_new_context_with_model(model, ctx_params);

-    //---------------------------------
-    // Tokenize the prompt :
-    //---------------------------------
+    // tokenize the prompt

     std::vector<llama_token> tokens_list;
-    tokens_list = ::llama_tokenize( ctx , params.prompt , true );
+    tokens_list = ::llama_tokenize(ctx, params.prompt, true);

-    const int max_context_size = llama_n_ctx( ctx );
-    const int max_tokens_list_size = max_context_size - 4 ;
+    const int max_context_size = llama_n_ctx(ctx);
+    const int max_tokens_list_size = max_context_size - 4;

-    if ( (int)tokens_list.size() > max_tokens_list_size )
-    {
-        fprintf( stderr , "%s: error: prompt too long (%d tokens, max %d)\n" ,
-             __func__ , (int)tokens_list.size() , max_tokens_list_size );
+    if ((int)tokens_list.size() > max_tokens_list_size) {
+        fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size);
         return 1;
     }

-    fprintf( stderr, "\n\n" );
-
-    // Print the tokens from the prompt :
+    fprintf(stderr, "\n\n");

-    for( auto id : tokens_list )
-    {
-        printf( "%s" , llama_token_to_str( ctx , id ) );
+    for (auto id : tokens_list) {
+        fprintf(stderr, "%s", llama_token_to_str(ctx, id));
     }

-    fflush(stdout);
-
+    fflush(stderr);

-    //---------------------------------
-    // Main prediction loop :
-    //---------------------------------
+    // main loop

     // The LLM keeps a contextual cache memory of previous token evaluation.
     // Usually, once this cache is full, it is required to recompute a compressed context based on previous
     // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
     // example, we will just stop the loop once this cache is full or once an end of stream is detected.

-    while ( llama_get_kv_cache_token_count( ctx ) < max_context_size )
-    {
-        //---------------------------------
-        // Evaluate the tokens :
-        //---------------------------------
+    while (llama_get_kv_cache_token_count(ctx) < max_context_size) {
+        // evaluate the transformer

-        if ( llama_eval( ctx , tokens_list.data() , int(tokens_list.size()) , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) )
-        {
-            fprintf( stderr, "%s : failed to eval\n" , __func__ );
+        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
             return 1;
         }

         tokens_list.clear();

-        //---------------------------------
-        // Select the best prediction :
-        //---------------------------------
+        // sample the next token

         llama_token new_token_id = 0;

-        auto logits = llama_get_logits( ctx );
-        auto n_vocab = llama_n_vocab( ctx ); // the size of the LLM vocabulary (in tokens)
+        auto logits = llama_get_logits(ctx);
+        auto n_vocab = llama_n_vocab(ctx);

         std::vector<llama_token_data> candidates;
-        candidates.reserve( n_vocab );
+        candidates.reserve(n_vocab);

-        for( llama_token token_id = 0 ; token_id < n_vocab ; token_id++ )
-        {
-            candidates.emplace_back( llama_token_data{ token_id , logits[ token_id ] , 0.0f } );
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
         }

         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

-        // Select it using the "Greedy sampling" method :
-        new_token_id = llama_sample_token_greedy( ctx , &candidates_p );
-
+        new_token_id = llama_sample_token_greedy(ctx , &candidates_p);

         // is it an end of stream ?
-        if ( new_token_id == llama_token_eos() )
-        {
+        if (new_token_id == llama_token_eos()) {
            fprintf(stderr, " [end of text]\n");
            break;
        }

-        // Print the new token :
-        printf( "%s" , llama_token_to_str( ctx , new_token_id ) );
-        fflush( stdout );
+        // print the new token :
+        printf("%s", llama_token_to_str(ctx, new_token_id));
+        fflush(stdout);

-        // Push this new token for next evaluation :
-        tokens_list.push_back( new_token_id );
+        // push this new token for next evaluation
+        tokens_list.push_back(new_token_id);

-    } // wend of main loop
+    }

-    llama_free( ctx );
-    llama_free_model( model );
+    llama_free(ctx);
+    llama_free_model(model);

     llama_backend_free();

     return 0;
 }
-
-// EOF
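
The sampling step in this example is deliberately minimal: it always takes the argmax over the logits via llama_sample_token_greedy. As a rough sketch only (not part of this commit), the same candidates_p array could instead be run through the stochastic samplers exposed by the llama.cpp API of this era; the helper below and its top-k/temperature values are illustrative assumptions.

// Sketch only (assumed API, not from this commit): swap the greedy argmax for
// top-k + temperature sampling. Parameter values are illustrative defaults.
static llama_token sample_with_temperature(llama_context * ctx, llama_token_data_array * candidates_p) {
    const int   top_k = 40;   // keep the 40 most likely candidates (assumed value)
    const float temp  = 0.8f; // soften the distribution before drawing

    llama_sample_top_k(ctx, candidates_p, top_k, 1);   // prune the candidate list
    llama_sample_temperature(ctx, candidates_p, temp); // rescale the remaining logits
    return llama_sample_token(ctx, candidates_p);      // draw a token instead of taking the argmax
}

In the loop above, new_token_id = sample_with_temperature(ctx, &candidates_p); would then replace the llama_sample_token_greedy call.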