@@ -1036,23 +1036,33 @@ extern "C" {
1036
1036
1037
1037
LLAMA_API void llama_sampling_free (struct llama_sampling * smpl);
1038
1038
1039
+ // Copies the internal state of the sampler (rng, prev, params, grammar, etc.)
1039
1040
LLAMA_API struct llama_sampling * llama_sampling_cp (const struct llama_sampling * smpl);
1040
1041
1041
1042
// - clear prev token
1042
1043
// - reset grammar state
1043
1044
LLAMA_API void llama_sampling_reset (struct llama_sampling * smpl);
1044
1045
1045
- LLAMA_API void llama_sampling_set_rng_seed (struct llama_sampling * smpl, uint32_t seed);
1046
+ // Sampling parameter mutation
1047
+ // TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable
1046
1048
LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
1047
1049
LLAMA_API void llama_sampling_set_logit_bias (struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);
1048
1050
1051
+ // Set the logits from which to sample.
1052
+ // This call initializes the internal token candidates array.
1053
+ // The internal candidates are implicitly used by the sampling API below when no candidates are provided.
1049
1054
LLAMA_API void llama_sampling_set_logits (
1050
1055
struct llama_sampling * smpl,
1051
1056
const float * logits);
1052
1057
1058
+ // / @details Returns the current candidate tokens.
1053
1059
LLAMA_API llama_token_data_array * llama_sampling_get_candidates (
1054
1060
struct llama_sampling * smpl);
1055
1061
1062
+ // The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object.
1063
+ // Each function can accept an array of token candidates. If the candidates are not provided, the internal
1064
+ // candidates are used. The internal candidates are initialized by llama_sampling_set_logits().
1065
+
1056
1066
// / @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
1057
1067
LLAMA_API void llama_sampling_softmax (
1058
1068
struct llama_sampling * smpl,
@@ -1115,17 +1125,22 @@ extern "C" {
1115
1125
struct llama_sampling * smpl,
1116
1126
llama_token_data_array * candidates);
1117
1127
1118
- // / @details Sample a token using the configured samplers.
1128
+ // / @details Sample a token using the configured samplers (see "llama_sampling_params.samplers") .
1119
1129
LLAMA_API llama_token llama_sampling_sample (
1120
1130
struct llama_sampling * smpl,
1121
1131
llama_token_data_array * candidates);
1122
1132
1123
- // / @details Accepts the sampled token into the sampling context
1133
+ // / @details Accepts the sampled token into the sampling context.
1134
+ // / - adds it to "prev" tokens
1135
+ // / - updates the grammar state (if apply_grammar is true)
1124
1136
LLAMA_API void llama_sampling_accept (
1125
1137
struct llama_sampling * smpl,
1126
1138
llama_token token,
1127
1139
bool apply_grammar);
1128
1140
1141
+ // / @details Get the number of accepted tokens so far (max of n_prev)
1142
+ LLAMA_API int llama_sampling_n_prev (const struct llama_sampling * smpl);
1143
+
1129
1144
// / @details Get the ith accepted token
1130
1145
// / @param ith [0, n_prev), ith == 0 is the last accepted token.
1131
1146
// / returns LLAMA_TOKEN_NULL if ith is out of bounds
@@ -1138,9 +1153,6 @@ extern "C" {
1138
1153
// / returns LLAMA_TOKEN_NULL if there are no accepted tokens
1139
1154
LLAMA_API llama_token llama_sampling_last (const struct llama_sampling * smpl);
1140
1155
1141
- // / @details Get the number of accepted tokens (max of n_prev)
1142
- LLAMA_API int llama_sampling_n_prev (const struct llama_sampling * smpl);
1143
-
1144
1156
//
1145
1157
// Model split
1146
1158
//
0 commit comments