Skip to content

Commit 7e2b5fb

Browse files
committed
sampling : add llama_sampling_print helper
1 parent b526561 commit 7e2b5fb

File tree

4 files changed

+21
-6
lines changed

4 files changed

+21
-6
lines changed

common/sampling.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,20 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
6666
dst->prev = src->prev;
6767
}
6868

69+
std::string llama_sampling_print(const llama_sampling_params & params) {
70+
char result[1024];
71+
72+
snprintf(result, sizeof(result),
73+
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
74+
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n"
75+
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
76+
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
77+
params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp,
78+
params.mirostat, params.mirostat_eta, params.mirostat_tau);
79+
80+
return std::string(result);
81+
}
82+
6983
llama_token llama_sampling_sample(
7084
struct llama_sampling_context * ctx_sampling,
7185
struct llama_context * ctx_main,

common/sampling.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ typedef struct llama_sampling_params {
3030

3131
// Classifier-Free Guidance
3232
// https://arxiv.org/abs/2306.17806
33-
std::string cfg_negative_prompt; // string to help guidance
34-
float cfg_scale = 1.f; // How strong is guidance
33+
std::string cfg_negative_prompt; // string to help guidance
34+
float cfg_scale = 1.f; // how strong is guidance
3535

3636
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
3737
} llama_sampling_params;
@@ -70,6 +70,9 @@ void llama_sampling_reset(llama_sampling_context * ctx);
7070
// Copy the sampler context
7171
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
7272

73+
// Print sampling parameters into a string
74+
std::string llama_sampling_print(const llama_sampling_params & params);
75+
7376
// this is a common sampling function used across the examples for convenience
7477
// it can serve as a starting point for implementing your own sampling function
7578
// Note: When using multiple sequences, it is the caller's responsibility to call

examples/infill/infill.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,7 @@ int main(int argc, char ** argv) {
358358
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
359359
}
360360
}
361-
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
362-
sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
361+
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
363362
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
364363
LOG_TEE("\n\n");
365364

examples/main/main.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,7 @@ int main(int argc, char ** argv) {
415415
}
416416
}
417417
}
418-
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, frequency_penalty = %f, presence_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
419-
sparams.penalty_last_n, sparams.penalty_repeat, sparams.penalty_freq, sparams.penalty_present, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
418+
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
420419
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
421420
LOG_TEE("\n\n");
422421

0 commit comments

Comments
 (0)