Commit 9013245

stricter validation of n_embd
1 parent 1ac73f4

File tree

1 file changed: +20 -9 lines

tools/mtmd/mtmd.cpp

Lines changed: 20 additions & 9 deletions
@@ -103,6 +103,7 @@ struct mtmd_context {
     bool print_timings;
     int n_threads;
     std::string media_marker;
+    const int n_embd_text;

     // these are not tokens, but strings used to mark the beginning and end of image/audio embeddings
     std::string img_beg;
@@ -137,7 +138,8 @@ struct mtmd_context {
         text_model   (text_model),
         print_timings(ctx_params.print_timings),
         n_threads    (ctx_params.n_threads),
-        media_marker (ctx_params.media_marker)
+        media_marker (ctx_params.media_marker),
+        n_embd_text  (llama_model_n_embd(text_model))
     {
         if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) {
             throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
@@ -156,12 +158,26 @@ struct mtmd_context {
         if (!ctx_v && !ctx_a) {
             throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
         }
+
+        // if both vision and audio mmproj are present, we need to validate their n_embd
+        if (ctx_v && ctx_a) {
+            int n_embd_v = clip_n_mmproj_embd(ctx_v);
+            int n_embd_a = clip_n_mmproj_embd(ctx_a);
+            if (n_embd_v != n_embd_a) {
+                throw std::runtime_error(string_format(
+                    "mismatch between vision and audio mmproj (n_embd_v = %d, n_embd_a = %d)\n",
+                    n_embd_v, n_embd_a));
+            }
+        }

-        if (llama_model_n_embd(text_model) != n_embd_projected()) {
+        // since we already validated the n_embd of the vision and audio mmproj,
+        // we can safely assume that they are the same
+        int n_embd_clip = clip_n_mmproj_embd(ctx_v ? ctx_v : ctx_a);
+        if (n_embd_text != n_embd_clip) {
             throw std::runtime_error(string_format(
                 "mismatch between text model (n_embd = %d) and mmproj (n_embd = %d)\n"
                 "hint: you may be using the wrong mmproj\n",
-                llama_model_n_embd(text_model), n_embd_projected()));
+                n_embd_text, n_embd_clip));
         }
         if (ctx_v) {
             init_vision();
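
With both hunks in place, construction enforces two invariants before any encoding happens: when vision and audio projectors are both loaded they must agree with each other, and whichever projector is present must match the text model's embedding width. A self-contained sketch of that two-stage check follows; validate_n_embd and its std::optional parameters are hypothetical, not part of the mtmd API:

// standalone sketch of the two-stage check above; validate_n_embd and its
// parameters are hypothetical stand-ins, not part of the mtmd API
#include <cstdio>
#include <optional>
#include <stdexcept>

// precondition (checked earlier in mtmd): at least one projector is loaded
static void validate_n_embd(std::optional<int> n_embd_v,
                            std::optional<int> n_embd_a,
                            int n_embd_text) {
    // stage 1: if both projectors exist, they must agree with each other
    if (n_embd_v && n_embd_a && *n_embd_v != *n_embd_a) {
        throw std::runtime_error("mismatch between vision and audio mmproj");
    }
    // stage 2: whichever projector exists must agree with the text model
    int n_embd_clip = n_embd_v ? *n_embd_v : *n_embd_a;
    if (n_embd_text != n_embd_clip) {
        throw std::runtime_error("mismatch between text model and mmproj");
    }
}

int main() {
    try {
        validate_n_embd(4096, 4096, 4096);         // ok
        validate_n_embd(4096, std::nullopt, 3072); // throws at stage 2
    } catch (const std::exception & e) {
        std::printf("validation failed: %s\n", e.what());
    }
    return 0;
}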
@@ -294,11 +310,6 @@ struct mtmd_context {
         return ctx_a ? clip_get_projector_type(ctx_a) : PROJECTOR_TYPE_UNKNOWN;
     }

-    // both audio and vision contexts have the n_embd output dimension
-    int n_embd_projected() const {
-        return clip_n_mmproj_embd(ctx_v ? ctx_v : ctx_a);
-    }
-
     ~mtmd_context() {
         clip_free(ctx_a);
         clip_free(ctx_v);
@@ -716,7 +727,7 @@ int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
             LOG_ERR("%s: model does not support audio input\n", __func__);
             return 1;
         }
-        int n_mmproj_embd = ctx->n_embd_projected();
+        int n_mmproj_embd = ctx->n_embd_text;
         ctx->image_embd_v.resize(chunk->tokens_audio->n_tokens * n_mmproj_embd);
         bool ok = clip_image_batch_encode(
             ctx->ctx_a,
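
Reading ctx->n_embd_text here is safe precisely because the constructor has already proven it equals the projector's output width; the destination buffer then needs one embedding vector per audio token. A small sketch of that sizing arithmetic, with make_embd_buffer as a hypothetical helper:

#include <vector>

// hypothetical helper: one flat float buffer holding n_tokens embedding
// vectors of width n_embd, laid out back to back
static std::vector<float> make_embd_buffer(int n_tokens, int n_embd) {
    return std::vector<float>(static_cast<size_t>(n_tokens) * n_embd);
}

// e.g. 32 audio tokens at n_embd = 4096 -> a buffer of 131072 floats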
