
Commit ffe6665

llava-cli : multiple images (#6969)

Authored by cpumaxx and root
Co-authored-by: root <[email protected]>
1 parent 24affa7 · commit ffe6665

3 files changed: +41 −32 lines changed
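With this change, --image can be given more than once on a single llava-cli command line, and each image is processed in turn against the same prompt. A hypothetical invocation (model and image paths are placeholders, following the updated usage string in the diff below):

    ./llava-cli -m llava-v1.5-7b/ggml-model-q5_k.gguf --mmproj llava-v1.5-7b/mmproj-model-f16.gguf --image first.jpg --image second.jpg --temp 0.1 -p "describe the image in detail."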

common/common.cpp (2 additions, 2 deletions)

@@ -893,7 +893,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
             invalid_param = true;
             return true;
         }
-        params.image = argv[i];
+        params.image.emplace_back(argv[i]);
         return true;
     }
     if (arg == "-i" || arg == "--interactive") {
@@ -1495,7 +1495,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -ps N, --p-split N    speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
     printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
-    printf("  --image IMAGE_FILE    path to an image file. use with multimodal models\n");
+    printf("  --image IMAGE_FILE    path to an image file. use with multimodal models. Specify multiple times for batching\n");
     if (llama_supports_mlock()) {
         printf("  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
     }
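On the common/ side this parser change is the whole mechanism: each occurrence of --image appends to a vector instead of overwriting a single string. A minimal standalone sketch of that accumulation pattern (this is not the actual llama.cpp argument parser; the flag-matching loop here is illustrative only):

    #include <cstdio>
    #include <string>
    #include <vector>

    int main(int argc, char ** argv) {
        std::vector<std::string> images;
        for (int i = 1; i < argc; i++) {
            // every "--image <path>" occurrence appends one entry,
            // mirroring params.image.emplace_back(argv[i]) in the diff
            if (std::string(argv[i]) == "--image" && i + 1 < argc) {
                images.emplace_back(argv[++i]);
            }
        }
        for (const auto & img : images) {
            printf("queued image: %s\n", img.c_str());
        }
        return 0;
    }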

common/common.h (2 additions, 2 deletions)

@@ -167,8 +167,8 @@ struct gpt_params {
     std::string cache_type_v = "f16"; // KV cache data type for the V

     // multimodal models (see examples/llava)
-    std::string mmproj = ""; // path to multimodal projector
-    std::string image = "";  // path to an image file
+    std::string mmproj = "";        // path to multimodal projector
+    std::vector<std::string> image; // path to image file(s)
 };

 bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
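Note on the header change: a std::vector<std::string> default-constructs to empty, so the old = "" initializer has no counterpart, and "no image was supplied" is naturally expressed as params.image.empty() rather than a comparison against an empty string.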

examples/llava/llava-cli.cpp (37 additions, 28 deletions)

@@ -113,11 +113,11 @@ struct llava_context {
 };

 static void show_additional_info(int /*argc*/, char ** argv) {
-    LOG_TEE("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
+    LOG_TEE("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> --image <path/to/another/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
     LOG_TEE(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
 }

-static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) {
+static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params, const std::string & fname) {

     // load and preprocess the image
     llava_image_embed * embed = NULL;
@@ -133,9 +133,9 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para
         }
         params->prompt = remove_image_from_prompt(prompt);
     } else {
-        embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str());
+        embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, fname.c_str());
         if (!embed) {
-            LOG_TEE("%s: is %s really an image file?\n", __func__, params->image.c_str());
+            fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
             return NULL;
         }
     }
@@ -207,17 +207,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     printf("\n");
 }

-
-static struct llava_context * llava_init(gpt_params * params) {
-    const char * clip_path = params->mmproj.c_str();
-
-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
-    }
-
-    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
-
+static struct llama_model * llava_init(gpt_params * params) {
     llama_backend_init();
     llama_numa_init(params->numa);

@@ -228,6 +218,19 @@ static struct llava_context * llava_init(gpt_params * params) {
         LOG_TEE("%s: error: unable to load model\n" , __func__);
         return NULL;
     }
+    return model;
+}
+
+static struct llava_context * llava_init_context(gpt_params * params, llama_model * model) {
+    const char * clip_path = params->mmproj.c_str();
+
+    auto prompt = params->prompt;
+    if (prompt.empty()) {
+        prompt = "describe the image in detail.";
+    }
+
+    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
+

     llama_context_params ctx_params = llama_context_params_from_gpt_params(*params);
     ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
@@ -286,24 +289,30 @@ int main(int argc, char ** argv) {
         show_additional_info(argc, argv);
         return 1;
     }
-
-    auto ctx_llava = llava_init(&params);
-    if (ctx_llava == NULL) {
-        LOG_TEE("%s: error: failed to init llava\n", __func__);
+    auto model = llava_init(&params);
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
         return 1;
     }

-    auto image_embed = load_image(ctx_llava, &params);
-    if (!image_embed) {
-        return 1;
-    }
+    for (auto & image : params.image) {
+        auto ctx_llava = llava_init_context(&params, model);

-    // process the prompt
-    process_prompt(ctx_llava, image_embed, &params, params.prompt);
+        auto image_embed = load_image(ctx_llava, &params, image);
+        if (!image_embed) {
+            std::cerr << "error: failed to load image " << image << ". Terminating\n\n";
+            return 1;
+        }
+
+        // process the prompt
+        process_prompt(ctx_llava, image_embed, &params, params.prompt);

-    llama_print_timings(ctx_llava->ctx_llama);
+        llama_print_timings(ctx_llava->ctx_llama);
+        llava_image_embed_free(image_embed);
+        ctx_llava->model = NULL;
+        llava_free(ctx_llava);
+    }
+    llama_free_model(model);

-    llava_image_embed_free(image_embed);
-    llava_free(ctx_llava);
     return 0;
 }
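The llava-cli.cpp refactor splits the old llava_init into a once-per-run model load (llava_init) plus a per-image context setup (llava_init_context), and main() now loops over params.image. Inside the loop, ctx_llava->model is set to NULL before llava_free so the per-image cleanup cannot release the shared model, which is freed exactly once with llama_free_model after the loop. A standalone sketch of this ownership pattern, using stand-in types (Model, Context, Embed) rather than the real llama.cpp ones:

    #include <cstdio>
    #include <string>
    #include <vector>

    struct Model   { /* weights, loaded once per run */ };
    struct Context { Model * model; /* per-image state */ };
    struct Embed   { std::string source; };

    int main() {
        const std::vector<std::string> images = { "a.jpg", "b.jpg" };

        Model * model = new Model();              // heavyweight: once per run
        for (const auto & image : images) {
            Context * ctx = new Context{ model }; // lightweight: once per image
            Embed * embed = new Embed{ image };
            std::printf("processing %s\n", embed->source.c_str());
            delete embed;                         // cf. llava_image_embed_free
            ctx->model = nullptr;                 // context does not own the model
            delete ctx;                           // cf. llava_free
        }
        delete model;                             // cf. llama_free_model, after the loop
        return 0;
    }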
