Skip to content

Commit 58db7a8

Browse files
authored
clip : fix batch inference for quantized models (#52)
* rm __pycache__
* gitignore __pycache__
* gitignore dist
* Fix batch inference for quantized models
1 parent f4935fc commit 58db7a8

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

clip.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ bool clip_text_encode(const clip_ctx * ctx, int n_threads, const std::vector<cli
883883

884884
// layernorm1
885885
{
886-
cur = ggml_norm(ctx0, cur);
886+
cur = ggml_norm(ctx0, cur, 1e-5f);
887887

888888
cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_1_w, cur), cur),
889889
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
@@ -933,7 +933,7 @@ bool clip_text_encode(const clip_ctx * ctx, int n_threads, const std::vector<cli
933933

934934
// layernorm2
935935
{
936-
cur = ggml_norm(ctx0, cur);
936+
cur = ggml_norm(ctx0, cur, 1e-5f);
937937

938938
cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_2_w, cur), cur),
939939
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
@@ -959,7 +959,7 @@ bool clip_text_encode(const clip_ctx * ctx, int n_threads, const std::vector<cli
959959

960960
// final -layer_norm
961961
{
962-
embeddings = ggml_norm(ctx0, embeddings);
962+
embeddings = ggml_norm(ctx0, embeddings, 1e-5f);
963963

964964
embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.post_ln_w, embeddings), embeddings),
965965
ggml_repeat(ctx0, model.post_ln_b, embeddings));
@@ -1136,7 +1136,7 @@ bool clip_image_batch_encode(const clip_ctx * ctx, int n_threads, const std::vec
11361136

11371137
// pre-layernorm
11381138
{
1139-
embeddings = ggml_norm(ctx0, embeddings);
1139+
embeddings = ggml_norm(ctx0, embeddings, 1e-5f);
11401140

11411141
embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.pre_ln_w, embeddings), embeddings),
11421142
ggml_repeat(ctx0, model.pre_ln_b, embeddings));
@@ -1152,7 +1152,7 @@ bool clip_image_batch_encode(const clip_ctx * ctx, int n_threads, const std::vec
11521152

11531153
// layernorm1
11541154
{
1155-
cur = ggml_norm(ctx0, cur);
1155+
cur = ggml_norm(ctx0, cur, 1e-5f);
11561156

11571157
cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_1_w, cur), cur),
11581158
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
@@ -1202,7 +1202,7 @@ bool clip_image_batch_encode(const clip_ctx * ctx, int n_threads, const std::vec
12021202

12031203
// layernorm2
12041204
{
1205-
cur = ggml_norm(ctx0, cur);
1205+
cur = ggml_norm(ctx0, cur, 1e-5f);
12061206

12071207
cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ln_2_w, cur), cur),
12081208
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
@@ -1235,7 +1235,7 @@ bool clip_image_batch_encode(const clip_ctx * ctx, int n_threads, const std::vec
12351235

12361236
// post-layernorm
12371237
{
1238-
embeddings = ggml_norm(ctx0, embeddings);
1238+
embeddings = ggml_norm(ctx0, embeddings, 1e-4f);
12391239

12401240
embeddings = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, model.post_ln_w, embeddings), embeddings),
12411241
ggml_repeat(ctx0, model.post_ln_b, embeddings));
@@ -1260,6 +1260,7 @@ bool clip_image_batch_encode(const clip_ctx * ctx, int n_threads, const std::vec
12601260
// run the computation
12611261
ggml_build_forward_expand(&gf, output);
12621262
ggml_cplan cplan = ggml_graph_plan(&gf, n_threads);
1263+
cplan.work_size *= batch_size;
12631264
if (cplan.work_size != 0) {
12641265
cplan.work_data = (uint8_t *)malloc(cplan.work_size);
12651266
}
@@ -1395,16 +1396,18 @@ bool softmax_with_sorting(float * arr, int length, float * sorted_scores, int *
13951396
}
13961397

13971398
// Calculate softmax probabilities
1399+
/*
13981400
float max_val = arr[0];
13991401
for (int i = 1; i < length; i++) {
14001402
if (arr[i] > max_val) {
14011403
max_val = arr[i];
14021404
}
14031405
}
1406+
*/
14041407

14051408
float sum = 0.0;
14061409
for (int i = 0; i < length; i++) {
1407-
arr[i] = exp(arr[i] - max_val);
1410+
arr[i] = exp(arr[i]);
14081411
sum += arr[i];
14091412
}
14101413

ggml

Submodule ggml updated from 1a5d5f3 to dd1d575

0 commit comments

Comments (0)