Skip to content

Add llama_sampler_init for safe usage of llama_sampler_free #11727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions common/llguidance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,10 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
};
}

return new llama_sampler{
return llama_sampler_init(
/* .iface = */ &llama_sampler_llg_i,
/* .ctx = */ ctx,
};
/* .ctx = */ ctx
);
}

#else
Expand Down
5 changes: 3 additions & 2 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1114,11 +1114,12 @@ extern "C" {
};

struct llama_sampler {
struct llama_sampler_i * iface;
llama_sampler_context_t ctx;
const struct llama_sampler_i * iface;
llama_sampler_context_t ctx;
};

// mirror of llama_sampler_i:
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
Expand Down
121 changes: 64 additions & 57 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) {

// llama_sampler API

struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
return new llama_sampler {
/* .iface = */ iface,
/* .ctx = */ ctx,
};
}

const char * llama_sampler_name(const struct llama_sampler * smpl) {
if (!smpl->iface) {
return "(null)";
Expand Down Expand Up @@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
}

if (smpl->ctx == nullptr) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ smpl->iface,
/* .ctx = */ nullptr,
};
/* .ctx = */ nullptr
);
}

GGML_ABORT("the sampler does not support cloning");
Expand Down Expand Up @@ -472,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
};

struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_chain_i,
/* .ctx = */ new llama_sampler_chain {
/* .params = */ params,
/* .samplers = */ {},
/* .t_sample_us = */ 0,
/* .n_sample = */ 0,
},
};
}
);
}

void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
Expand Down Expand Up @@ -546,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
};

struct llama_sampler * llama_sampler_init_greedy() {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_greedy_i,
/* .ctx = */ nullptr,
};
/* .ctx = */ nullptr
);
}

// dist
Expand Down Expand Up @@ -608,14 +615,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {

struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist {
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
);
}

// softmax
Expand All @@ -638,10 +645,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
};

struct llama_sampler * llama_sampler_init_softmax() {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_softmax_i,
/* .ctx = */ nullptr,
};
/* .ctx = */ nullptr
);
}

// top-k
Expand Down Expand Up @@ -678,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
};

struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_k_i,
/* .ctx = */ new llama_sampler_top_k {
/* .k = */ k,
},
};
}
);
}

// top-p
Expand Down Expand Up @@ -744,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
};

struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_p_i,
/* .ctx = */ new llama_sampler_top_p {
/* .p = */ p,
/* .min_keep = */ min_keep,
},
};
}
);
}

// min-p
Expand Down Expand Up @@ -840,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
};

struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_min_p_i,
/* .ctx = */ new llama_sampler_min_p {
/* .p = */ p,
/* .min_keep = */ min_keep,
},
};
}
);
}

// typical
Expand Down Expand Up @@ -939,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
};

struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_typical_i,
/* .ctx = */ new llama_sampler_typical {
/* .p = */ p,
/* .min_keep = */ min_keep,
},
};
}
);
}

// temp
Expand Down Expand Up @@ -983,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
};

struct llama_sampler * llama_sampler_init_temp(float temp) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_temp_i,
/* .ctx = */ new llama_sampler_temp {
/*.temp = */ temp,
},
};
}
);
}

// temp-ext
Expand Down Expand Up @@ -1093,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
};

struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_temp_ext_i,
/* .ctx = */ new llama_sampler_temp_ext {
/* .temp = */ temp,
/* .delta = */ delta,
/* .exponent = */ exponent,
},
};
}
);
}

// xtc
Expand Down Expand Up @@ -1185,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {

struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc {
/* .probability = */ p,
Expand All @@ -1194,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
);
}

// mirostat
Expand Down Expand Up @@ -1292,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {

struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_mirostat_i,
/* .ctx = */ new llama_sampler_mirostat {
/* .n_vocab = */ n_vocab,
Expand All @@ -1303,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
/* .m = */ m,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
);
}

// mirostat v2
Expand Down Expand Up @@ -1391,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {

struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_mirostat_v2_i,
/* .ctx = */ new llama_sampler_mirostat_v2 {
/* .seed = */ seed,
Expand All @@ -1400,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
/* .eta = */ eta,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
);
}

// grammar
Expand Down Expand Up @@ -1528,10 +1535,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
};
}

return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_grammar_i,
/* .ctx = */ ctx,
};
/* .ctx = */ ctx
);
}

struct llama_sampler * llama_sampler_init_grammar(
Expand Down Expand Up @@ -1678,7 +1685,7 @@ struct llama_sampler * llama_sampler_init_penalties(
float penalty_present) {
penalty_last_n = std::max(penalty_last_n, 0);

return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties {
/* .penalty_last_n = */ penalty_last_n,
Expand All @@ -1687,8 +1694,8 @@ struct llama_sampler * llama_sampler_init_penalties(
/* .penalty_present = */ penalty_present,
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
/* .token_count = */ {},
},
};
}
);
}

// DRY
Expand Down Expand Up @@ -2041,7 +2048,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
}
}

return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_dry_i,
/* .ctx = */ new llama_sampler_dry {
/* .total_context_size = */ context_size,
Expand All @@ -2053,8 +2060,8 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
/* .dry_max_token_repeat = */ {},
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
},
};
}
);
}

// wrapper for test-sampling.cpp
Expand Down Expand Up @@ -2155,14 +2162,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_logit_bias_i,
/* .ctx = */ new llama_sampler_logit_bias {
/* .n_vocab = */ n_vocab,
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
/* .to_search = */ {},
},
};
}
);
}

// infill
Expand Down Expand Up @@ -2377,14 +2384,14 @@ static struct llama_sampler_i llama_sampler_infill_i = {
};

struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill {
/* .vocab = */ vocab,
/* .buf0 = */ std::vector<char>(512),
/* .buf1 = */ std::vector<char>(512),
},
};
}
);
}

// utils
Expand Down
Loading