Skip to content

Commit bf14ca7

Browse files
cfillionNeoZhangJianyu
authored andcommitted
llama : add llama_sampler_init for safe usage of llama_sampler_free (ggml-org#11727)
The C API in llama.h claims users can implement `llama_sampler_i` to create custom `llama_sampler`. The sampler chain takes ownership and calls `llama_sampler_free` on them. However, `llama_sampler_free` is hard-coded to use `delete`. This is undefined behavior if the object wasn't also allocated via `new` from libllama's C++ runtime. Callers in C and C-compatible languages do not use C++'s `new` operator. C++ callers may not be sharing the same heap as libllama.
1 parent cf7f1c1 commit bf14ca7

File tree

3 files changed

+70
-62
lines changed

3 files changed

+70
-62
lines changed

common/llguidance.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,10 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
254254
};
255255
}
256256

257-
return new llama_sampler{
257+
return llama_sampler_init(
258258
/* .iface = */ &llama_sampler_llg_i,
259-
/* .ctx = */ ctx,
260-
};
259+
/* .ctx = */ ctx
260+
);
261261
}
262262

263263
#else

include/llama.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,11 +1114,12 @@ extern "C" {
11141114
};
11151115

11161116
struct llama_sampler {
1117-
struct llama_sampler_i * iface;
1118-
llama_sampler_context_t ctx;
1117+
const struct llama_sampler_i * iface;
1118+
llama_sampler_context_t ctx;
11191119
};
11201120

11211121
// mirror of llama_sampler_i:
1122+
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
11221123
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
11231124
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
11241125
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);

src/llama-sampling.cpp

Lines changed: 64 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) {
316316

317317
// llama_sampler API
318318

319+
struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
320+
return new llama_sampler {
321+
/* .iface = */ iface,
322+
/* .ctx = */ ctx,
323+
};
324+
}
325+
319326
const char * llama_sampler_name(const struct llama_sampler * smpl) {
320327
if (!smpl->iface) {
321328
return "(null)";
@@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
347354
}
348355

349356
if (smpl->ctx == nullptr) {
350-
return new llama_sampler {
357+
return llama_sampler_init(
351358
/* .iface = */ smpl->iface,
352-
/* .ctx = */ nullptr,
353-
};
359+
/* .ctx = */ nullptr
360+
);
354361
}
355362

356363
GGML_ABORT("the sampler does not support cloning");
@@ -472,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
472479
};
473480

474481
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
475-
return new llama_sampler {
482+
return llama_sampler_init(
476483
/* .iface = */ &llama_sampler_chain_i,
477484
/* .ctx = */ new llama_sampler_chain {
478485
/* .params = */ params,
479486
/* .samplers = */ {},
480487
/* .t_sample_us = */ 0,
481488
/* .n_sample = */ 0,
482-
},
483-
};
489+
}
490+
);
484491
}
485492

486493
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
@@ -546,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
546553
};
547554

548555
struct llama_sampler * llama_sampler_init_greedy() {
549-
return new llama_sampler {
556+
return llama_sampler_init(
550557
/* .iface = */ &llama_sampler_greedy_i,
551-
/* .ctx = */ nullptr,
552-
};
558+
/* .ctx = */ nullptr
559+
);
553560
}
554561

555562
// dist
@@ -608,14 +615,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {
608615

609616
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
610617
auto seed_cur = get_rng_seed(seed);
611-
return new llama_sampler {
618+
return llama_sampler_init(
612619
/* .iface = */ &llama_sampler_dist_i,
613620
/* .ctx = */ new llama_sampler_dist {
614621
/* .seed = */ seed,
615622
/* .seed_cur = */ seed_cur,
616623
/* .rng = */ std::mt19937(seed_cur),
617-
},
618-
};
624+
}
625+
);
619626
}
620627

621628
// softmax
@@ -638,10 +645,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
638645
};
639646

640647
struct llama_sampler * llama_sampler_init_softmax() {
641-
return new llama_sampler {
648+
return llama_sampler_init(
642649
/* .iface = */ &llama_sampler_softmax_i,
643-
/* .ctx = */ nullptr,
644-
};
650+
/* .ctx = */ nullptr
651+
);
645652
}
646653

647654
// top-k
@@ -678,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
678685
};
679686

680687
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
681-
return new llama_sampler {
688+
return llama_sampler_init(
682689
/* .iface = */ &llama_sampler_top_k_i,
683690
/* .ctx = */ new llama_sampler_top_k {
684691
/* .k = */ k,
685-
},
686-
};
692+
}
693+
);
687694
}
688695

689696
// top-p
@@ -744,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
744751
};
745752

746753
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
747-
return new llama_sampler {
754+
return llama_sampler_init(
748755
/* .iface = */ &llama_sampler_top_p_i,
749756
/* .ctx = */ new llama_sampler_top_p {
750757
/* .p = */ p,
751758
/* .min_keep = */ min_keep,
752-
},
753-
};
759+
}
760+
);
754761
}
755762

756763
// min-p
@@ -840,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
840847
};
841848

842849
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
843-
return new llama_sampler {
850+
return llama_sampler_init(
844851
/* .iface = */ &llama_sampler_min_p_i,
845852
/* .ctx = */ new llama_sampler_min_p {
846853
/* .p = */ p,
847854
/* .min_keep = */ min_keep,
848-
},
849-
};
855+
}
856+
);
850857
}
851858

852859
// typical
@@ -939,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
939946
};
940947

941948
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
942-
return new llama_sampler {
949+
return llama_sampler_init(
943950
/* .iface = */ &llama_sampler_typical_i,
944951
/* .ctx = */ new llama_sampler_typical {
945952
/* .p = */ p,
946953
/* .min_keep = */ min_keep,
947-
},
948-
};
954+
}
955+
);
949956
}
950957

951958
// temp
@@ -983,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
983990
};
984991

985992
struct llama_sampler * llama_sampler_init_temp(float temp) {
986-
return new llama_sampler {
993+
return llama_sampler_init(
987994
/* .iface = */ &llama_sampler_temp_i,
988995
/* .ctx = */ new llama_sampler_temp {
989996
/*.temp = */ temp,
990-
},
991-
};
997+
}
998+
);
992999
}
9931000

9941001
// temp-ext
@@ -1093,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
10931100
};
10941101

10951102
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
1096-
return new llama_sampler {
1103+
return llama_sampler_init(
10971104
/* .iface = */ &llama_sampler_temp_ext_i,
10981105
/* .ctx = */ new llama_sampler_temp_ext {
10991106
/* .temp = */ temp,
11001107
/* .delta = */ delta,
11011108
/* .exponent = */ exponent,
1102-
},
1103-
};
1109+
}
1110+
);
11041111
}
11051112

11061113
// xtc
@@ -1185,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
11851192

11861193
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
11871194
auto seed_cur = get_rng_seed(seed);
1188-
return new llama_sampler {
1195+
return llama_sampler_init(
11891196
/* .iface = */ &llama_sampler_xtc_i,
11901197
/* .ctx = */ new llama_sampler_xtc {
11911198
/* .probability = */ p,
@@ -1194,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
11941201
/* .seed = */ seed,
11951202
/* .seed_cur = */ seed_cur,
11961203
/* .rng = */ std::mt19937(seed_cur),
1197-
},
1198-
};
1204+
}
1205+
);
11991206
}
12001207

12011208
// mirostat
@@ -1292,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
12921299

12931300
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
12941301
auto seed_cur = get_rng_seed(seed);
1295-
return new llama_sampler {
1302+
return llama_sampler_init(
12961303
/* .iface = */ &llama_sampler_mirostat_i,
12971304
/* .ctx = */ new llama_sampler_mirostat {
12981305
/* .n_vocab = */ n_vocab,
@@ -1303,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
13031310
/* .m = */ m,
13041311
/* .mu = */ 2.0f*tau,
13051312
/* .rng = */ std::mt19937(seed_cur),
1306-
},
1307-
};
1313+
}
1314+
);
13081315
}
13091316

13101317
// mirostat v2
@@ -1391,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
13911398

13921399
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
13931400
auto seed_cur = get_rng_seed(seed);
1394-
return new llama_sampler {
1401+
return llama_sampler_init(
13951402
/* .iface = */ &llama_sampler_mirostat_v2_i,
13961403
/* .ctx = */ new llama_sampler_mirostat_v2 {
13971404
/* .seed = */ seed,
@@ -1400,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
14001407
/* .eta = */ eta,
14011408
/* .mu = */ 2.0f*tau,
14021409
/* .rng = */ std::mt19937(seed_cur),
1403-
},
1404-
};
1410+
}
1411+
);
14051412
}
14061413

14071414
// grammar
@@ -1528,10 +1535,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
15281535
};
15291536
}
15301537

1531-
return new llama_sampler {
1538+
return llama_sampler_init(
15321539
/* .iface = */ &llama_sampler_grammar_i,
1533-
/* .ctx = */ ctx,
1534-
};
1540+
/* .ctx = */ ctx
1541+
);
15351542
}
15361543

15371544
struct llama_sampler * llama_sampler_init_grammar(
@@ -1678,7 +1685,7 @@ struct llama_sampler * llama_sampler_init_penalties(
16781685
float penalty_present) {
16791686
penalty_last_n = std::max(penalty_last_n, 0);
16801687

1681-
return new llama_sampler {
1688+
return llama_sampler_init(
16821689
/* .iface = */ &llama_sampler_penalties_i,
16831690
/* .ctx = */ new llama_sampler_penalties {
16841691
/* .penalty_last_n = */ penalty_last_n,
@@ -1687,8 +1694,8 @@ struct llama_sampler * llama_sampler_init_penalties(
16871694
/* .penalty_present = */ penalty_present,
16881695
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
16891696
/* .token_count = */ {},
1690-
},
1691-
};
1697+
}
1698+
);
16921699
}
16931700

16941701
// DRY
@@ -2041,7 +2048,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
20412048
}
20422049
}
20432050

2044-
return new llama_sampler {
2051+
return llama_sampler_init(
20452052
/* .iface = */ &llama_sampler_dry_i,
20462053
/* .ctx = */ new llama_sampler_dry {
20472054
/* .total_context_size = */ context_size,
@@ -2053,8 +2060,8 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
20532060
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
20542061
/* .dry_max_token_repeat = */ {},
20552062
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
2056-
},
2057-
};
2063+
}
2064+
);
20582065
}
20592066

20602067
// wrapper for test-sampling.cpp
@@ -2155,14 +2162,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
21552162
int32_t n_vocab,
21562163
int32_t n_logit_bias,
21572164
const llama_logit_bias * logit_bias) {
2158-
return new llama_sampler {
2165+
return llama_sampler_init(
21592166
/* .iface = */ &llama_sampler_logit_bias_i,
21602167
/* .ctx = */ new llama_sampler_logit_bias {
21612168
/* .n_vocab = */ n_vocab,
21622169
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
21632170
/* .to_search = */ {},
2164-
},
2165-
};
2171+
}
2172+
);
21662173
}
21672174

21682175
// infill
@@ -2377,14 +2384,14 @@ static struct llama_sampler_i llama_sampler_infill_i = {
23772384
};
23782385

23792386
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
2380-
return new llama_sampler {
2387+
return llama_sampler_init(
23812388
/* .iface = */ &llama_sampler_infill_i,
23822389
/* .ctx = */ new llama_sampler_infill {
23832390
/* .vocab = */ vocab,
23842391
/* .buf0 = */ std::vector<char>(512),
23852392
/* .buf1 = */ std::vector<char>(512),
2386-
},
2387-
};
2393+
}
2394+
);
23882395
}
23892396

23902397
// utils

0 commit comments

Comments
 (0)