Skip to content

Commit e84773a

Browse files
ahmedshakillCISCngxson
authored
mtmd-cli : fix out_of_range when input image path is empty (#13244)
* fix out_of_range error to keep the chat loop running * Update examples/llava/mtmd-cli.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * mtmd-cli : load image right away * add a new line for readability * rm printf * Update examples/llava/mtmd-cli.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update examples/llava/mtmd-cli.cpp --------- Co-authored-by: Sigbjørn Skjæret <[email protected]> Co-authored-by: Xuan Son Nguyen <[email protected]> Co-authored-by: Xuan-Son Nguyen <[email protected]>
1 parent fab647e commit e84773a

File tree

2 files changed

+41
-33
lines changed

2 files changed

+41
-33
lines changed

examples/llava/mtmd-cli.cpp

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ struct mtmd_cli_context {
7272
llama_batch batch;
7373
int n_batch;
7474

75+
std::vector<mtmd_bitmap> bitmaps;
76+
7577
// note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
7678
// so here we don't need to keep track of chat history
7779
common_chat_templates_ptr tmpls;
@@ -135,13 +137,22 @@ struct mtmd_cli_context {
135137
antiprompt_tokens.begin()
136138
);
137139
}
140+
141+
bool load_image(const std::string & fname) {
142+
mtmd_bitmap bitmap;
143+
if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
144+
return false;
145+
}
146+
bitmaps.push_back(std::move(bitmap));
147+
return true;
148+
}
138149
};
139150

140151
static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
141152
llama_tokens generated_tokens;
142153
for (int i = 0; i < n_predict; i++) {
143154
if (i > n_predict || !g_is_generating || g_is_interrupted) {
144-
printf("\n");
155+
LOG("\n");
145156
break;
146157
}
147158

@@ -150,15 +161,15 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
150161
common_sampler_accept(smpl, token_id, true);
151162

152163
if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
153-
printf("\n");
164+
LOG("\n");
154165
break; // end of generation
155166
}
156167

157-
printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
168+
LOG("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
158169
fflush(stdout);
159170

160171
if (g_is_interrupted) {
161-
printf("\n");
172+
LOG("\n");
162173
break;
163174
}
164175

@@ -173,25 +184,14 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
173184
return 0;
174185
}
175186

176-
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
177-
std::vector<mtmd_bitmap> bitmaps;
178-
187+
static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
179188
common_chat_templates_inputs tmpl_inputs;
180189
tmpl_inputs.messages = {msg};
181190
tmpl_inputs.add_generation_prompt = true;
182191
tmpl_inputs.use_jinja = false; // jinja is buggy here
183192
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
184193
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
185194

186-
for (auto & fname : images_fname) {
187-
mtmd_bitmap bitmap;
188-
if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) {
189-
LOG_ERR("Unable to load image %s\n", fname.c_str());
190-
return 2; // image not found
191-
}
192-
bitmaps.push_back(std::move(bitmap));
193-
}
194-
195195
mtmd_input_text text;
196196
text.text = formatted_chat.prompt;
197197
text.add_special = add_bos;
@@ -200,19 +200,23 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
200200

201201
if (g_is_interrupted) return 0;
202202

203-
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
203+
int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, ctx.bitmaps);
204204
if (res != 0) {
205205
LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
206206
return 1;
207207
}
208208

209+
ctx.bitmaps.clear();
210+
209211
if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
210212
LOG_ERR("Unable to eval prompt\n");
211213
return 1;
212214
}
213215

214216
ctx.n_past += mtmd_helper_get_n_pos(chunks);
215217

218+
LOG("\n");
219+
216220
return 0;
217221
}
218222

@@ -235,7 +239,7 @@ int main(int argc, char ** argv) {
235239
}
236240

237241
mtmd_cli_context ctx(params);
238-
printf("%s: %s\n", __func__, params.model.path.c_str());
242+
LOG("%s: loading model: %s\n", __func__, params.model.path.c_str());
239243

240244
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
241245

@@ -268,7 +272,12 @@ int main(int argc, char ** argv) {
268272
common_chat_msg msg;
269273
msg.role = "user";
270274
msg.content = params.prompt;
271-
if (eval_message(ctx, msg, params.image, true)) {
275+
for (const auto & image : params.image) {
276+
if (!ctx.load_image(image)) {
277+
return 1; // error is already printed by libmtmd
278+
}
279+
}
280+
if (eval_message(ctx, msg, true)) {
272281
return 1;
273282
}
274283
if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
@@ -283,7 +292,6 @@ int main(int argc, char ** argv) {
283292
LOG("\n");
284293

285294
bool is_first_msg = true;
286-
std::vector<std::string> images_fname;
287295
std::string content;
288296

289297
while (!g_is_interrupted) {
@@ -308,32 +316,32 @@ int main(int argc, char ** argv) {
308316
continue;
309317
}
310318
g_is_generating = true;
311-
if (line.find("/image") == 0) {
319+
if (line == "/image" || line.find("/image ") == 0) {
320+
if (line.size() < 8) {
321+
LOG_ERR("ERR: Missing image filename\n");
322+
continue;
323+
}
312324
std::string image = line.substr(7);
313-
images_fname.push_back(string_strip(image));
314-
content += "<__image__>";
325+
if (ctx.load_image(image)) {
326+
LOG("Image %s loaded\n", image.c_str());
327+
content += "<__image__>";
328+
}
329+
// else, error is already printed by libmtmd
315330
continue;
316331
} else {
317332
content += line;
318333
}
319334
common_chat_msg msg;
320335
msg.role = "user";
321336
msg.content = content;
322-
int ret = eval_message(ctx, msg, images_fname, is_first_msg);
323-
if (g_is_interrupted) break;
324-
if (ret == 2) {
325-
// non-fatal error
326-
images_fname.clear();
327-
content.clear();
328-
continue;
329-
}
337+
int ret = eval_message(ctx, msg, is_first_msg);
330338
if (ret) {
331339
return 1;
332340
}
341+
if (g_is_interrupted) break;
333342
if (generate_response(ctx, smpl, n_predict)) {
334343
return 1;
335344
}
336-
images_fname.clear();
337345
content.clear();
338346
is_first_msg = false;
339347
}

examples/llava/mtmd.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
590590
}
591591

592592
} else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
593-
GGML_ASSERT(!is_last && "logits for last image chunk is not yet support");
593+
GGML_ASSERT(!is_last && "logits for last image chunk is not yet supported");
594594
GGML_ASSERT(chunk.tokens_image != nullptr);
595595
int64_t t0 = ggml_time_ms();
596596
if (ctx->print_timings) {

0 commit comments

Comments
 (0)