Add Granite Vision Support #11794


Merged on Feb 24, 2025 (40 commits)

Changes from all commits:
7db1c51
Add super wip scripts for multimodal granite gguf
alex-jw-brooks Jan 16, 2025
f6fbfc2
Add example for converting mmgranite to gguf
alex-jw-brooks Jan 16, 2025
8b60107
remove hardcoded path
alex-jw-brooks Jan 16, 2025
72c53dd
Add vision feature layer to gguf params
alex-jw-brooks Jan 22, 2025
4212f37
Clean up llava surgery and remove name substitution hacks
alex-jw-brooks Jan 22, 2025
77ce6f2
Add transformers llava next tensor name mapping
alex-jw-brooks Jan 22, 2025
3fc67a9
Make siglip / openclip mutually exclusive
alex-jw-brooks Jan 22, 2025
e3c791c
Fix projector linear substitution
alex-jw-brooks Jan 26, 2025
2d45e0d
Fix linear 2 substitution index
alex-jw-brooks Jan 27, 2025
3e3eebc
Increase max flattened gridpoints to 64
alex-jw-brooks Jan 27, 2025
1ceef1a
Fix hardcoded concat for multiple feature layers
alex-jw-brooks Jan 27, 2025
c788a45
Pull vision feature layers out of gguf keys
alex-jw-brooks Jan 27, 2025
6725d6c
fix num gridpoints and use all layers
alex-jw-brooks Feb 5, 2025
403575c
Avoid dropping last image encoder layer in llava models
alex-jw-brooks Feb 5, 2025
095b836
Use 10 for max number of patches
alex-jw-brooks Feb 5, 2025
ff00515
Standardize vision feature layers
alex-jw-brooks Feb 10, 2025
eceee7f
Cleanup logs
alex-jw-brooks Feb 10, 2025
be204f0
Update comment for vision feature layer init
alex-jw-brooks Feb 10, 2025
4588b90
Update notes for alternative to legacy llm conversion script
alex-jw-brooks Feb 10, 2025
264c2ca
Fix notes rendering
alex-jw-brooks Feb 10, 2025
86b43da
Add v prefix to vision feature layer log
alex-jw-brooks Feb 10, 2025
eb54540
Use current defaults for feature layer
alex-jw-brooks Feb 12, 2025
243a899
Use constant for max gridpoints / feat layers, style fixes
alex-jw-brooks Feb 12, 2025
ee2cf62
clarify non-negative feature layers
alex-jw-brooks Feb 12, 2025
b5735ba
Remove CLIP_API from func signature
alex-jw-brooks Feb 12, 2025
3670d0e
USE MAX_IMAGE_FEATURE_LAYERS const in layer calc
alex-jw-brooks Feb 12, 2025
b973f37
Clarify feature layers are non negative ints and not uint
alex-jw-brooks Feb 12, 2025
82e8852
Fix condition for reading feature layers
alex-jw-brooks Feb 12, 2025
ee6fb4d
pop last llava layer when feature layers are unset
alex-jw-brooks Feb 12, 2025
07e7716
Fix unset vision layer 0
alex-jw-brooks Feb 13, 2025
5f69fdb
Update examples/llava/clip.cpp
alex-jw-brooks Feb 20, 2025
ab522d7
Reenable assertion for out of bounds get_rows
alex-jw-brooks Feb 20, 2025
bb3e03a
Use std vector for gridpoints and feature layers
alex-jw-brooks Feb 21, 2025
7bab305
Calculate max feature layer at load time
alex-jw-brooks Feb 21, 2025
6557523
Include base patch for granite vision allocation
alex-jw-brooks Feb 21, 2025
cce01b8
Fix trailing whitespace
alex-jw-brooks Feb 21, 2025
188bfb0
Add max num patches = 10 back for minicpmv
alex-jw-brooks Feb 21, 2025
8676316
Use unordered set to store feature layers
alex-jw-brooks Feb 24, 2025
49c0863
Use max feature layer for postnorm
alex-jw-brooks Feb 24, 2025
bec9ef1
Apply suggestions from code review
ngxson Feb 24, 2025
19 changes: 19 additions & 0 deletions examples/llava/README.md
@@ -101,8 +101,27 @@ python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow
```

**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)

**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)

**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way to handle the LLM conversion is to load the model in transformers and export only the LLM from the llava-next model.

```python
import os
import transformers

model_path = ...
llm_export_path = ...

tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
model = transformers.AutoModelForImageTextToText.from_pretrained(model_path)

tokenizer.save_pretrained(llm_export_path)
model.language_model.save_pretrained(llm_export_path)
```

Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures.
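
For example, a conversion call might look like the following sketch (the export path is a placeholder for wherever you saved the LLM above; `--outtype f16` is optional):

```console
python ./convert_hf_to_gguf.py /path/to/llm_export --outfile granite_llm.gguf --outtype f16
```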

## llava-cli templating and llava-1.6 prompting

llava-1.5 models all use the same vicuna prompt; here you can just add your image question, e.g. `-p "Provide a full description."`
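
For illustration, a full call might look like the following sketch (the binary and file names here are assumptions that depend on your build and conversion output):

```console
./llama-llava-cli -m ../llava-v1.6-vicuna-7b/ggml-model-f16.gguf \
    --mmproj ../llava-v1.6-vicuna-7b/mmproj-model-f16.gguf \
    --image path/to/image.jpg -c 4096 -p "Provide a full description."
```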
108 changes: 89 additions & 19 deletions examples/llava/clip.cpp
@@ -40,6 +40,7 @@
#include <map>
#include <regex>
#include <stdexcept>
#include <unordered_set>
#include <vector>
#include <sstream>
#include <cinttypes>
@@ -120,6 +121,7 @@ static std::string format(const char * fmt, ...) {
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"

#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@@ -444,8 +446,9 @@ struct clip_hparams {

char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default)

int32_t image_grid_pinpoints[32];
std::vector<int32_t> image_grid_pinpoints;
int32_t image_crop_resolution;
std::unordered_set<int32_t> vision_feature_layer;
};

struct clip_layer {
@@ -585,6 +588,7 @@ struct clip_ctx {
struct clip_vision_model vision_model;
projector_type proj_type = PROJECTOR_TYPE_MLP;

int32_t max_feature_layer;
float image_mean[3];
float image_std[3];
bool use_gelu = false;
@@ -651,7 +655,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
int n_layer = hparams.n_layer;
const float eps = hparams.eps;
int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};

@@ -752,13 +755,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
}

std::vector<struct ggml_tensor *> embedding_stack;
const auto & vision_feature_layer = hparams.vision_feature_layer;

// loop over layers
if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
n_layer += 1;
}
for (int il = 0; il < n_layer - 1; il++) {
for (int il = 0; il < ctx->max_feature_layer; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states

// If this is an embedding feature layer, save the output.
// NOTE: 0 index here refers to the input to the encoder.
if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
embedding_stack.push_back(embeddings);
}

//const size_t nb_q_w = model.layers[il].q_w->nb[0];

// layernorm1
@@ -846,7 +855,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
cur = ggml_add(ctx0, embeddings, cur);

embeddings = cur;

}

// post-layernorm
Expand All @@ -857,6 +865,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
}

// final layer is a vision feature layer
if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) {
embedding_stack.push_back(embeddings);
}

// If feature layers are explicitly set, stack them (if we have multiple)
if (!embedding_stack.empty()) {
embeddings = embedding_stack[0];
for (size_t i = 1; i < embedding_stack.size(); i++) {
embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
}
}

// llava projector
if (ctx->has_llava_projector) {
embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
@@ -1443,14 +1464,26 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
int n = gguf_get_arr_n(ctx, idx);
const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx);
for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) {
hparams.image_grid_pinpoints[i] = pinpoints[i];
for (int i = 0; i < n; ++i) {
hparams.image_grid_pinpoints.push_back(pinpoints[i]);
}
if (n < 32)
hparams.image_grid_pinpoints[n] = 0;
} catch (std::runtime_error & /*e*/) {
hparams.image_grid_pinpoints[0]=0;
}
} catch (std::runtime_error & /*e*/) { }

// Load the vision feature layer indices if they are explicitly provided;
// if multiple vision feature layers are present, the values will be concatenated
// to form the final visual features.
// NOTE: gguf conversions should standardize the values of the vision feature layer to
// be non-negative, since we use -1 to mark values as unset here.
try {
int idx = get_key_idx(ctx, KEY_FEATURE_LAYER);
int n = gguf_get_arr_n(ctx, idx);

const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx);

for (int i = 0; i < n; ++i) {
hparams.vision_feature_layer.insert(vision_feature_layer[i]);
}
} catch (std::runtime_error & /*e*/) { }

try {
int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE);
@@ -1476,6 +1509,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
new_clip->image_std[i] = std_data[i];
}

// Calculate the deepest feature layer based on hparams and projector type
new_clip->max_feature_layer = get_deepest_feature_layer(new_clip);

if (verbosity >= 2) {
LOG_INF("\n%s: vision model hparams\n", __func__);
LOG_INF("image_size %d\n", hparams.image_size);
@@ -1489,8 +1525,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]);
LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]);
LOG_INF("v_image_grid_pinpoints: ");
for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) {
LOG_INF("%d ", hparams.image_grid_pinpoints[i]);
for (const auto & pp : hparams.image_grid_pinpoints) {
LOG_INF("%d ", pp);
}
LOG_INF("\n");
LOG_INF("v_vision_feature_layer: ");
for (const auto & feature_layer: hparams.vision_feature_layer) {
LOG_INF("%d ", feature_layer);
}
LOG_INF("\n");
LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type);
@@ -2235,10 +2276,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
}
}
} else {
if (params.image_grid_pinpoints[0] != 0) {
if (!params.image_grid_pinpoints.empty()) {
// "spatial_unpad" with "anyres" processing for llava-1.6
std::vector<std::pair<int, int>> possible_resolutions;
for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
for (size_t i = 0; i < params.image_grid_pinpoints.size(); i+=2) {
possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
}
std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);
@@ -2404,7 +2445,14 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
}

const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.image_grid_pinpoints;
if (ctx->vision_model.hparams.image_grid_pinpoints.size()) {
return &ctx->vision_model.hparams.image_grid_pinpoints.front();
}
return nullptr;
}

size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
return ctx->vision_model.hparams.image_grid_pinpoints.size();
}

int clip_n_patches(const struct clip_ctx * ctx) {
Expand Down Expand Up @@ -2929,6 +2977,28 @@ bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
return ctx->has_qwen2vl_merger;
}

// Determine the number of encoder layers to iterate over
int get_deepest_feature_layer(const struct clip_ctx * ctx) {
// Get the index of the second to last layer; this is the
// default for models that have a llava projector
const auto & hparams = ctx->vision_model.hparams;
int n_layer = hparams.n_layer - 1;
int deepest_feature_layer = -1;

// Handle other projectors; incrementing here indicates that we
// should use the last encoder layer for the vision features.
if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
n_layer += 1;
}

// If we set explicit vision feature layers, only go up to the deepest one
for (const auto & feature_layer : hparams.vision_feature_layer) {
if (feature_layer > deepest_feature_layer) {
deepest_feature_layer = feature_layer;
}
}
return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
}

bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img;
5 changes: 4 additions & 1 deletion examples/llava/clip.h
@@ -55,6 +55,7 @@ CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);

CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);

CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
@@ -92,11 +93,13 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);

CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);

CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);

CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);

CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);

#ifdef __cplusplus
}