Skip to content

Commit 94d1505

Browse files
committed
Re-add the JNA compatible sampler
1 parent 370359e commit 94d1505

File tree

4 files changed

+182
-0
lines changed

4 files changed

+182
-0
lines changed

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@ else()
3939
add_subdirectory(server)
4040
endif()
4141
add_subdirectory(export-lora)
42+
add_subdirectory(grammar)
4243
endif()

examples/grammar/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Grammar-constrained sampler, always built as a shared library from grammar.cpp
# so that it can be loaded dynamically by a foreign-function binding.
set(TARGET grammar)
add_library(${TARGET} SHARED grammar.cpp)

# Shared objects must be relocatable.
set_property(TARGET ${TARGET} PROPERTY POSITION_INDEPENDENT_CODE ON)

# NOTE(review): LLAMA_SHARED / LLAMA_BUILD presumably select the "exporting"
# form of the API macro in llama.h - confirm against that header.
target_compile_definitions(${TARGET} PRIVATE LLAMA_SHARED LLAMA_BUILD)

# The sampler relies on the grammar parser from `common` and on the core `llama` library.
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})

examples/grammar/grammar.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#include "grammar.h"
2+
#include <stdlib.h>
3+
4+
struct llama_grammar * llama_cached_parse_grammar(const char * grammar_str) {
5+
static std::unordered_map<std::string, grammar_parser::parse_state> parsed_grammar_cache;
6+
std::string key = grammar_str;
7+
8+
auto it = parsed_grammar_cache.find(key);
9+
grammar_parser::parse_state parsed_grammar;
10+
if (it != parsed_grammar_cache.end()) {
11+
// Use cached parsed grammar
12+
parsed_grammar = it->second;
13+
} else {
14+
// Parse and cache the result
15+
parsed_grammar = grammar_parser::parse(grammar_str);
16+
parsed_grammar_cache[key] = parsed_grammar;
17+
18+
// Optionally print the grammar
19+
grammar_parser::print_grammar(stderr, parsed_grammar);
20+
}
21+
22+
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
23+
24+
struct llama_grammar * grammar = NULL;
25+
grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
26+
27+
return grammar;
28+
}
29+
30+
struct llama_sampler_params llama_sampler_default_params() {
31+
return llama_sampler_params();
32+
}
33+
34+
/**
 * Sample the next token from `cur_p`, optionally constrained by a grammar.
 *
 * Steps, in order:
 *   1. repetition / frequency / presence penalties over the recent-token history,
 *   2. grammar filtering of the candidates (when `grammar` is not NULL),
 *   3. token selection: greedy when temp <= 0, mirostat v1/v2 when
 *      params.mirostat is 1 or 2, otherwise top-k -> tail-free -> typical ->
 *      top-p -> temperature sampling,
 *   4. the grammar is advanced with the chosen token and the token is appended
 *      to the history.
 *
 * The token history and the mirostat `mu` value live in function-local statics,
 * so they are shared by every caller in the process. The function is therefore
 * neither reentrant nor usable for several independent generations at once.
 *
 * @param ctx      llama context used for all sampling calls; must not be NULL.
 * @param grammar  grammar to constrain sampling with, or NULL for none.
 * @param params   sampling settings (passed by value).
 * @param cur_p    candidate tokens; modified in place by the sampling calls.
 * @param reset    when true, the token history is cleared (zero-filled) and the
 *                 mirostat state is re-seeded before sampling. Pass true for
 *                 the first token of a new generation.
 * @return the sampled token.
 */
llama_token llama_grammar_sample_token(struct llama_context * ctx,
                                       struct llama_grammar * grammar,
                                       struct llama_sampler_params params,
                                       struct llama_token_data_array * cur_p,
                                       bool reset) {
    const int    n_ctx       = llama_n_ctx(ctx);
    const size_t history_len = n_ctx > 0 ? static_cast<size_t>(n_ctx) : 0;

    // State that has to survive between calls.
    static std::vector<llama_token> last_tokens;
    static float mirostat_mu        = 0.0f;
    static bool  mirostat_mu_seeded = false;

    // Start from a clean, zero-filled history only when asked to, or when the
    // context size no longer matches the stored history (this also covers the
    // very first call). Clearing it on every call would hide all previously
    // sampled tokens from the penalties below.
    if (reset || last_tokens.size() != history_len) {
        last_tokens.assign(history_len, 0);
        mirostat_mu_seeded = false;
    }

    const float   temp            = params.temp;
    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
    const float   repeat_penalty  = params.repeat_penalty;
    const float   alpha_presence  = params.presence_penalty;
    const float   alpha_frequency = params.frequency_penalty;
    const int     mirostat        = params.mirostat;
    const float   mirostat_tau    = params.mirostat_tau;
    const float   mirostat_eta    = params.mirostat_eta;
    const int32_t top_k           = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params.top_k;
    const float   top_p           = params.top_p;
    const float   tfs_z           = params.tfs_z;
    const float   typical_p       = params.typical_p;
    const int32_t n_probs         = params.n_probs;

    llama_token result = -1;

    // apply penalties over the most recent `last_n_repeat` history entries
    if (!last_tokens.empty()) {
        const int last_n_repeat =
            std::min(std::min(static_cast<int>(last_tokens.size()), repeat_last_n), n_ctx);
        const llama_token * recent = last_tokens.data() + last_tokens.size() - last_n_repeat;

        llama_sample_repetition_penalty(ctx, cur_p, recent, last_n_repeat, repeat_penalty);
        llama_sample_frequency_and_presence_penalties(ctx, cur_p, recent, last_n_repeat,
                                                      alpha_frequency, alpha_presence);
    }

    if (grammar != NULL) {
        llama_sample_grammar(ctx, cur_p, grammar);
    }

    if (temp <= 0) {
        // Greedy sampling
        result = llama_sample_token_greedy(ctx, cur_p);
    } else if (mirostat == 1 || mirostat == 2) {
        // Seed mu once per generation, so that `reset` also restarts mirostat.
        if (!mirostat_mu_seeded) {
            mirostat_mu        = 2.0f * mirostat_tau;
            mirostat_mu_seeded = true;
        }

        llama_sample_temp(ctx, cur_p, temp);
        if (mirostat == 1) {
            const int mirostat_m = 100;
            result = llama_sample_token_mirostat(ctx, cur_p, mirostat_tau, mirostat_eta,
                                                 mirostat_m, &mirostat_mu);
        } else {
            result = llama_sample_token_mirostat_v2(ctx, cur_p, mirostat_tau, mirostat_eta,
                                                    &mirostat_mu);
        }
    } else {
        // Temperature sampling
        const size_t min_keep = std::max(1, n_probs);
        llama_sample_top_k(ctx, cur_p, top_k, min_keep);
        llama_sample_tail_free(ctx, cur_p, tfs_z, min_keep);
        llama_sample_typical(ctx, cur_p, typical_p, min_keep);
        llama_sample_top_p(ctx, cur_p, top_p, min_keep);
        llama_sample_temp(ctx, cur_p, temp);
        result = llama_sample_token(ctx, cur_p);
    }

    if (grammar != NULL) {
        llama_grammar_accept_token(ctx, grammar, result);
    }

    // Slide the fixed-size history window: drop the oldest entry, append the
    // new token. Skipped for an empty history, where erase(begin()) would be
    // undefined behaviour.
    if (!last_tokens.empty()) {
        last_tokens.erase(last_tokens.begin());
        last_tokens.push_back(result);
    }

    return result;
}

examples/grammar/grammar.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#ifndef GRAMMAR_H
2+
#define GRAMMAR_H
3+
4+
#include <string>
5+
#include <vector>
6+
#include <cstdint>
7+
#include <unordered_map>
8+
#include <stddef.h>
9+
#include <stdint.h>
10+
#include <stdbool.h>
11+
12+
13+
#include "llama.h"
14+
#include "grammar-parser.h"
15+
16+
#ifdef __cplusplus
17+
extern "C" {
18+
#endif
19+
// llama_sampler.h
20+
21+
#pragma once
22+
23+
24+
/**
 * Sampling settings consumed by llama_grammar_sample_token().
 *
 * The struct is passed by value across the library boundary, so the order and
 * types of the fields are part of the binary interface.
 */
struct llama_sampler_params {
    float   temp              = 0.80f; ///< sampling temperature; values <= 0 select greedy sampling
    float   repeat_penalty    = 1.10f; ///< penalty applied to recently repeated tokens
    int32_t repeat_last_n     = 64;    ///< history length for the penalties; negative means the whole context
    float   frequency_penalty = 0.00f; ///< frequency penalty strength
    float   presence_penalty  = 0.00f; ///< presence penalty strength
    int32_t mirostat          = 2;     ///< 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
    float   mirostat_tau      = 5.00f; ///< mirostat target entropy
    float   mirostat_eta      = 0.10f; ///< mirostat learning rate
    int32_t top_k             = 40;    ///< top-k cut-off; values <= 0 use the full vocabulary
    float   top_p             = 0.95f; ///< top-p (nucleus) threshold
    float   tfs_z             = 1.0f;  ///< tail-free sampling parameter z
    float   typical_p         = 1.0f;  ///< locally typical sampling parameter p
    int32_t n_probs           = 0;     ///< number of probabilities to output (0 for none); also the minimum candidates kept
};
39+
40+
llama_sampler_params llama_sampler_default_params();
41+
42+
struct llama_grammar * llama_cached_parse_grammar(const char * grammar_str);
43+
44+
llama_token llama_grammar_sample_token(llama_context * ctx,
45+
llama_grammar * grammar,
46+
llama_sampler_params params,
47+
llama_token_data_array * cur_p,
48+
bool reset);
49+
50+
51+
52+
#ifdef __cplusplus
53+
}
54+
#endif
55+
56+
#endif // GRAMMAR_H

0 commit comments

Comments
 (0)