Skip to content

Commit bd35cb0

Browse files
authored
feat: remove a sampler from a chain (#9445)
* feat: remove a sampler from a chain * fix: return removed sampler * fix: safer casting
1 parent 7820364 commit bd35cb0

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,9 @@ extern "C" {
10561056
LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
10571057
LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);
10581058

1059+
// after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
1060+
LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i);
1061+
10591062
// available samplers:
10601063

10611064
LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void);

src/llama-sampling.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,13 +349,26 @@ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler
349349
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
350350
const auto * p = (const llama_sampler_chain *) chain->ctx;
351351

352-
if (i < 0 || i >= (int32_t) p->samplers.size()) {
352+
if (i < 0 || (size_t) i >= p->samplers.size()) {
353353
return nullptr;
354354
}
355355

356356
return p->samplers[i];
357357
}
358358

359+
struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
360+
auto * p = (llama_sampler_chain *) chain->ctx;
361+
362+
if (i < 0 || (size_t) i >= p->samplers.size()) {
363+
return nullptr;
364+
}
365+
366+
auto * result = p->samplers[i];
367+
p->samplers.erase(p->samplers.begin() + i);
368+
369+
return result;
370+
}
371+
359372
int llama_sampler_chain_n(const struct llama_sampler * chain) {
360373
const auto * p = (const llama_sampler_chain *) chain->ctx;
361374

0 commit comments

Comments
 (0)