@@ -72,6 +72,8 @@ struct mtmd_cli_context {
     llama_batch batch;
     int n_batch;
 
+    std::vector<mtmd_bitmap> bitmaps;
+
     // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
     // so here we don't need to keep track of chat history
     common_chat_templates_ptr tmpls;
@@ -135,13 +137,22 @@ struct mtmd_cli_context {
             antiprompt_tokens.begin()
         );
     }
+
+    bool load_image(const std::string & fname) {
+        mtmd_bitmap bitmap;
+        if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
+            return false;
+        }
+        bitmaps.push_back(std::move(bitmap));
+        return true;
+    }
 };
 
 static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating || g_is_interrupted) {
-            printf("\n");
+            LOG("\n");
             break;
         }
 
@@ -150,15 +161,15 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
         common_sampler_accept(smpl, token_id, true);
 
         if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
-            printf("\n");
+            LOG("\n");
             break; // end of generation
         }
 
-        printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
+        LOG("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
         fflush(stdout);
 
         if (g_is_interrupted) {
-            printf("\n");
+            LOG("\n");
             break;
         }
 
@@ -173,25 +184,14 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
     return 0;
 }
 
-static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
-    std::vector<mtmd_bitmap> bitmaps;
-
+static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
     common_chat_templates_inputs tmpl_inputs;
     tmpl_inputs.messages = {msg};
     tmpl_inputs.add_generation_prompt = true;
     tmpl_inputs.use_jinja = false; // jinja is buggy here
     auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
     LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
 
-    for (auto & fname : images_fname) {
-        mtmd_bitmap bitmap;
-        if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
-            LOG_ERR("Unable to load image %s\n", fname.c_str());
-            return 2; // image not found
-        }
-        bitmaps.push_back(std::move(bitmap));
-    }
-
     mtmd_input_text text;
     text.text        = formatted_chat.prompt;
     text.add_special = add_bos;
@@ -200,19 +200,23 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
 
     if (g_is_interrupted) return 0;
 
-    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
+    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, ctx.bitmaps);
     if (res != 0) {
         LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
         return 1;
     }
 
+    ctx.bitmaps.clear();
+
     if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
         LOG_ERR("Unable to eval prompt\n");
         return 1;
     }
 
     ctx.n_past += mtmd_helper_get_n_pos(chunks);
 
+    LOG("\n");
+
     return 0;
 }
 
@@ -235,7 +239,7 @@ int main(int argc, char ** argv) {
     }
 
     mtmd_cli_context ctx(params);
-    printf("%s: %s\n", __func__, params.model.path.c_str());
+    LOG("%s: loading model: %s\n", __func__, params.model.path.c_str());
 
     bool is_single_turn = !params.prompt.empty() && !params.image.empty();
 
@@ -268,7 +272,12 @@ int main(int argc, char ** argv) {
         common_chat_msg msg;
         msg.role = "user";
         msg.content = params.prompt;
-        if (eval_message(ctx, msg, params.image, true)) {
+        for (const auto & image : params.image) {
+            if (!ctx.load_image(image)) {
+                return 1; // error is already printed by libmtmd
+            }
+        }
+        if (eval_message(ctx, msg, true)) {
             return 1;
         }
         if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
@@ -283,7 +292,6 @@ int main(int argc, char ** argv) {
         LOG("\n");
 
         bool is_first_msg = true;
-        std::vector<std::string> images_fname;
         std::string content;
 
         while (!g_is_interrupted) {
@@ -308,32 +316,32 @@ int main(int argc, char ** argv) {
                 continue;
             }
             g_is_generating = true;
-            if (line.find("/image") == 0) {
+            if (line == "/image" || line.find("/image ") == 0) {
+                if (line.size() < 8) {
+                    LOG_ERR("ERR: Missing image filename\n");
+                    continue;
+                }
                 std::string image = line.substr(7);
-                images_fname.push_back(string_strip(image));
-                content += "<__image__>";
+                if (ctx.load_image(image)) {
+                    LOG("Image %s loaded\n", image.c_str());
+                    content += "<__image__>";
+                }
+                // else, error is already printed by libmtmd
                 continue;
             } else {
                 content += line;
            }
             common_chat_msg msg;
             msg.role = "user";
             msg.content = content;
-            int ret = eval_message(ctx, msg, images_fname, is_first_msg);
-            if (g_is_interrupted) break;
-            if (ret == 2) {
-                // non-fatal error
-                images_fname.clear();
-                content.clear();
-                continue;
-            }
+            int ret = eval_message(ctx, msg, is_first_msg);
             if (ret) {
                 return 1;
             }
+            if (g_is_interrupted) break;
             if (generate_response(ctx, smpl, n_predict)) {
                 return 1;
             }
-            images_fname.clear();
             content.clear();
             is_first_msg = false;
         }
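
Note (not part of the patch): the new mtmd_cli_context::load_image() inverts the error convention of the helper it wraps. mtmd_helper_bitmap_init_from_file() returns non-zero on failure, while load_image() returns true on success, which is why the single-turn path bails out on !ctx.load_image(image) and the interactive path treats a true return as the success branch. A minimal sketch of the intended call pattern after this change, assuming ctx, msg, params, smpl and n_predict are already set up as in mtmd-cli.cpp:

    // Sketch only; mirrors the single-turn path and relies on the
    // mtmd_cli_context / eval_message / generate_response defined above.
    for (const auto & fname : params.image) {
        if (!ctx.load_image(fname)) {
            return 1; // libmtmd has already printed the error
        }
    }
    // eval_message() now tokenizes against ctx.bitmaps and clears it on
    // success, so loaded images apply only to the next evaluated message.
    if (eval_message(ctx, msg, /*add_bos=*/true)) {
        return 1;
    }
    if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
        return 1;
    }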