Skip to content

Commit 584ef0e

Browse files
committed
llama : add comments [no ci]
1 parent 5dde421 commit 584ef0e

File tree

4 files changed

+30
-19
lines changed

4 files changed

+30
-19
lines changed

include/llama.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,23 +1036,33 @@ extern "C" {
10361036

10371037
LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);
10381038

1039+
// Copies the internal state of the sampler (rng, prev, params, grammar, etc.)
10391040
LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl);
10401041

10411042
// - clear prev token
10421043
// - reset grammar state
10431044
LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl);
10441045

1045-
LLAMA_API void llama_sampling_set_rng_seed (struct llama_sampling * smpl, uint32_t seed);
1046+
// Sampling parameter mutation
1047+
// TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable
10461048
LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
10471049
LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);
10481050

1051+
// Set the logits from which to sample.
1052+
// This call initializes the internal token candidates array.
1053+
// The internal candidates are implicitly used by the sampling API below when no candidates are provided.
10491054
LLAMA_API void llama_sampling_set_logits(
10501055
struct llama_sampling * smpl,
10511056
const float * logits);
10521057

1058+
/// @details Returns the current candidate tokens.
10531059
LLAMA_API llama_token_data_array * llama_sampling_get_candidates(
10541060
struct llama_sampling * smpl);
10551061

1062+
// The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object.
1063+
// Each function can accept an array of token candidates. If the candidates are not provided, the internal
1064+
// candidates are used. The internal candidates are initialized by llama_sampling_set_logits().
1065+
10561066
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
10571067
LLAMA_API void llama_sampling_softmax(
10581068
struct llama_sampling * smpl,
@@ -1115,17 +1125,22 @@ extern "C" {
11151125
struct llama_sampling * smpl,
11161126
llama_token_data_array * candidates);
11171127

1118-
/// @details Sample a token using the configured samplers.
1128+
/// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers").
11191129
LLAMA_API llama_token llama_sampling_sample(
11201130
struct llama_sampling * smpl,
11211131
llama_token_data_array * candidates);
11221132

1123-
/// @details Accepts the sampled token into the sampling context
1133+
/// @details Accepts the sampled token into the sampling context.
1134+
/// - adds it to "prev" tokens
1135+
/// - updates the grammar state (if apply_grammar is true)
11241136
LLAMA_API void llama_sampling_accept(
11251137
struct llama_sampling * smpl,
11261138
llama_token token,
11271139
bool apply_grammar);
11281140

1141+
/// @details Get the number of accepted tokens so far (max of n_prev)
1142+
LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl);
1143+
11291144
/// @details Get the ith accepted token
11301145
/// @param ith [0, n_prev), ith == 0 is the last accepted token.
11311146
/// returns LLAMA_TOKEN_NULL if ith is out of bounds
@@ -1138,9 +1153,6 @@ extern "C" {
11381153
/// returns LLAMA_TOKEN_NULL if there are no accepted tokens
11391154
LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl);
11401155

1141-
/// @details Get the number of accepted tokens (max of n_prev)
1142-
LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl);
1143-
11441156
//
11451157
// Model split
11461158
//

src/llama-sampling.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,12 @@ void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, s
186186
int ib = nbuckets - 1;
187187
for ( ; ib >= 0; --ib) {
188188
nhave += histo[ib];
189-
if (nhave >= k) break;
189+
if (nhave >= k) {
190+
break;
191+
}
190192
}
191193
std::vector<llama_token_data> tmp_tokens(nhave);
192-
auto ptr = tmp_tokens.data();
194+
auto * ptr = tmp_tokens.data();
193195
std::vector<llama_token_data*> bucket_ptrs;
194196
bucket_ptrs.reserve(nbuckets - ib);
195197
for (int j = nbuckets - 1; j >= ib; --j) {
@@ -573,6 +575,7 @@ llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array
573575
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
574576
return candidate.id == X;
575577
}));
578+
576579
float observed_surprise = -log2f(candidates->data[X_idx].p);
577580
float e = observed_surprise - tau;
578581

src/llama-sampling.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,10 @@ llama_token llama_sampling_sample_mirostat_impl (struct llama_token_data_array
9898
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
9999
llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu);
100100

101-
llama_token llama_sampling_sample_greedy_impl (struct llama_token_data_array * candidates);
102-
llama_token llama_sampling_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng);
101+
llama_token llama_sampling_sample_greedy_impl(struct llama_token_data_array * candidates);
102+
llama_token llama_sampling_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng);
103103

104104
void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar);
105105

106-
llama_token llama_sampling_prev_impl(const struct llama_sampling & smpl, int ith);
106+
llama_token llama_sampling_prev_impl (const struct llama_sampling & smpl, int ith);
107107
int llama_sampling_n_prev_impl(const struct llama_sampling & smpl);

src/llama.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20106,10 +20106,6 @@ void llama_sampling_reset(struct llama_sampling * smpl) {
2010620106
llama_sampling_reset_impl(*smpl);
2010720107
}
2010820108

20109-
void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) {
20110-
llama_sampling_set_rng_seed_impl(*smpl, seed);
20111-
}
20112-
2011320109
void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
2011420110
llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root);
2011520111
}
@@ -20414,6 +20410,10 @@ void llama_sampling_accept(
2041420410
smpl->n_accept++;
2041520411
}
2041620412

20413+
int llama_sampling_n_prev(const struct llama_sampling * smpl) {
20414+
return llama_sampling_n_prev_impl(*smpl);
20415+
}
20416+
2041720417
llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) {
2041820418
return llama_sampling_prev_impl(*smpl, ith);
2041920419
}
@@ -20422,10 +20422,6 @@ llama_token llama_sampling_last(const struct llama_sampling * smpl) {
2042220422
return llama_sampling_prev_impl(*smpl, 0);
2042320423
}
2042420424

20425-
int llama_sampling_n_prev(const struct llama_sampling * smpl) {
20426-
return llama_sampling_n_prev_impl(*smpl);
20427-
}
20428-
2042920425
//
2043020426
// model split
2043120427
//

0 commit comments

Comments
 (0)