Skip to content

Commit 333fe3c

Browse files
committed
refactor inp_raw set
1 parent ce94be1 commit 333fe3c

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

examples/llava/clip.cpp

Lines changed: 22 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -586,7 +586,6 @@ static ggml_tensor * build_rope_2d(
586586
ggml_row_size(cur->type, n_dim),
587587
ggml_row_size(cur->type, n_dim*n_head),
588588
0);
589-
// first = ggml_cont(ctx0, first);
590589
first = ggml_rope_ext(
591590
ctx0,
592591
first,
@@ -599,7 +598,7 @@ static ggml_tensor * build_rope_2d(
599598
}
600599

601600
// second half (write to tmp)
602-
ggml_tensor * second = cur;
601+
ggml_tensor * second;
603602
{
604603
second = ggml_view_3d(ctx0, cur,
605604
n_dim/2, n_head, n_pos,
@@ -2825,9 +2824,20 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
28252824

28262825
{
28272826
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
2828-
float * data = (float *)malloc(ggml_nbytes(inp_raw));
2827+
std::vector<float> inp_data(ggml_nelements(inp_raw));
2828+
float * data = inp_data.data();
2829+
2830+
// layout of data (note: the channel dim is unrolled to better visualize the layout):
2831+
//
2832+
// ┌──W──┐
2833+
// │ H │ channel = R
2834+
// ├─────┤ │
2835+
// │ H │ channel = G
2836+
// ├─────┤ │
2837+
// │ H │ channel = B
2838+
// └─────┘ │
2839+
// ──────┘ x B
28292840

2830-
// TODO @ngxson : this whole code block is ugly, will need to be refactored
28312841
for (size_t i = 0; i < imgs.entries.size(); i++) {
28322842
const int nx = imgs.entries[i]->nx;
28332843
const int ny = imgs.entries[i]->ny;
@@ -2842,17 +2852,19 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
28422852
const int n = nx * ny;
28432853

28442854
for (int b = 0; b < batch_size; b++) {
2845-
for (int k = 0; k < 3; k++) {
2846-
for (int y = 0; y < ny; y++) {
2847-
for (int x = 0; x < nx; x++) {
2848-
data[(b * 3 * n) + k * n + y * nx + x] = imgs.entries[b]->buf[3 * (y * nx + x) + k];
2849-
}
2855+
float * batch_entry = data + b * (3*n);
2856+
for (int y = 0; y < ny; y++) {
2857+
for (int x = 0; x < nx; x++) {
2858+
size_t base_src = 3*(y * nx + x); // idx of the first channel
2859+
size_t base_dst = y * nx + x; // idx of the first channel
2860+
batch_entry[ base_dst] = imgs.entries[b]->buf[base_src ];
2861+
batch_entry[1*n + base_dst] = imgs.entries[b]->buf[base_src + 1];
2862+
batch_entry[2*n + base_dst] = imgs.entries[b]->buf[base_src + 2];
28502863
}
28512864
}
28522865
}
28532866
}
28542867
ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
2855-
free(data);
28562868
}
28572869
if (ctx->has_minicpmv_projector) {
28582870
{

0 commit comments

Comments
 (0)