@@ -70,6 +70,7 @@ struct mtmd_cli_context {
     llama_model * model;
     llama_context * lctx;
     const llama_vocab * vocab;
+    common_sampler * smpl;
     llama_batch batch;
     int n_batch;
 
@@ -89,8 +90,9 @@ struct mtmd_cli_context {
         model = llama_init.model.get();
         lctx = llama_init.context.get();
         vocab = llama_model_get_vocab(model);
+        smpl = common_sampler_init(model, params.sampling);
         n_threads = params.cpuparams.n_threads;
-        batch = llama_batch_init(params.n_batch, 0, 1);
+        batch = llama_batch_init(1, 0, 1); // batch for next token generation
         n_batch = params.n_batch;
 
         if (!model || !lctx) {
@@ -118,6 +120,11 @@ struct mtmd_cli_context {
         }
     }
 
+    ~mtmd_cli_context() {
+        llama_batch_free(batch);
+        common_sampler_free(smpl);
+    }
+
     void init_vision_context(common_params & params) {
         const char * clip_path = params.mmproj.path.c_str();
         mtmd_context_params mparams = mtmd_context_params_default();
@@ -153,17 +160,17 @@ struct mtmd_cli_context {
     }
 };
 
-static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
+static int generate_response(mtmd_cli_context & ctx, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating || g_is_interrupted) {
             LOG("\n");
             break;
         }
 
-        llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
+        llama_token token_id = common_sampler_sample(ctx.smpl, ctx.lctx, -1);
         generated_tokens.push_back(token_id);
-        common_sampler_accept(smpl, token_id, true);
+        common_sampler_accept(ctx.smpl, token_id, true);
 
         if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
             LOG("\n");
@@ -261,7 +268,6 @@ int main(int argc, char ** argv) {
 
     bool is_single_turn = !params.prompt.empty() && !params.image.empty();
 
-    struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
     int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
 
     // Ctrl+C handling
@@ -300,7 +306,7 @@ int main(int argc, char ** argv) {
         if (eval_message(ctx, msg, true)) {
             return 1;
         }
-        if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
+        if (!g_is_interrupted && generate_response(ctx, n_predict)) {
             return 1;
         }
 
@@ -366,7 +372,7 @@ int main(int argc, char ** argv) {
             return 1;
         }
         if (g_is_interrupted) break;
-        if (generate_response(ctx, smpl, n_predict)) {
+        if (generate_response(ctx, n_predict)) {
             return 1;
         }
         content.clear();
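
Taken together, these hunks move sampler ownership out of main() and into mtmd_cli_context, which now frees both the batch and the sampler in its new destructor, so generate_response no longer needs a sampler parameter. A minimal sketch of that RAII ownership pattern, using hypothetical stand-in types (resource, resource_init, resource_free) rather than the real llama.cpp/common API:

#include <cstdio>

// Hypothetical stand-ins for common_sampler / llama_batch and their
// init/free routines; not part of the real llama.cpp API.
struct resource { int id; };
static resource * resource_init(int id)       { return new resource{id}; }
static void       resource_free(resource * r) { delete r; }

// The pattern the commit adopts: the context struct acquires its resources
// in the constructor and releases them in the destructor, so callers reach
// them through the context instead of owning and freeing them separately.
struct cli_context {
    resource * smpl;

    cli_context()  : smpl(resource_init(1)) {}
    ~cli_context() { resource_free(smpl); }
};

// Mirrors the new generate_response signature: no sampler parameter.
static void generate_response(cli_context & ctx) {
    std::printf("sampling with resource %d\n", ctx.smpl->id);
}

int main() {
    cli_context ctx;        // sampler-like resource created here
    generate_response(ctx); // callees use ctx.smpl
    return 0;
}                           // destructor frees it exactly once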