
Commit c51f7d8

pre-7
Sampler and feature status:
- seed: ok
- top-k: ok
- top-p: ok
- min-p: ok
- temperature: ok
- grammar: ko
- repeat/penalty: ok
- logits: ok
- stats: ko
- rope: ok
- speculative decoding: ?
- cache: ko
- lora: ok
1 parent 9488bf1 · commit c51f7d8


3 files changed: +11 −46 lines


example/chat.dart

Lines changed: 4 additions & 4 deletions

@@ -26,9 +26,10 @@ void main() async {
   contextParams.context = 512 * 4;

   Llama llama = Llama(
-      "/Users/adel/Workspace/llama.cpp/models/pivot-10.7b-mistral-v0.2-rp.Q5_K_S.gguf",
+      "/Users/adel/Workspace/llama.cpp/models/openhermes-2.5-neural-chat-v3-3-slerp.Q5_K_M.gguf",
       modelParams,
-      contextParams);
+      contextParams,
+      samplingParams);

   ChatMLFormat chatMLFormat = ChatMLFormat();
   // AlpacaFormat alpacaFormat = AlpacaFormat();

@@ -43,7 +44,7 @@ Context: Teplizumab traces its roots to a New Jersey drug company called Ortho P

   llama.setPrompt(system + prompt);
   while (true) {
-    var (token, done) = llama.getNext(samplingParams);
+    var (token, done) = llama.getNext();
     String? chunk = chatMLFormat.filterResponse(token);
     if (chunk != null) stdout.write(token);
     if (done) break;

@@ -53,7 +54,6 @@ Context: Teplizumab traces its roots to a New Jersey drug company called Ortho P
   llama.clear();
   stdout.write("\n");

-  //*
   prompt = chatMLFormat.preparePrompt("What was the company called?");
   llama.setPrompt(system + prompt);
   while (true) {
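Net effect of the chat.dart hunks: sampling configuration now travels through the Llama constructor once, and getNext() takes no arguments. A minimal sketch of the resulting call pattern, assuming the constructor order shown above; the model path and prompt are placeholders, not from the commit:

  import 'dart:io';
  import 'package:llama_cpp_dart/llama_cpp_dart.dart';

  void main() {
    ContextParams contextParams = ContextParams();
    contextParams.context = 512 * 4;

    // Sampling is configured once, at construction.
    Llama llama =
        Llama("model.gguf", ModelParams(), contextParams, SamplingParams());
    ChatMLFormat chatMLFormat = ChatMLFormat();

    llama.setPrompt(chatMLFormat.preparePrompt("Hello"));
    while (true) {
      var (token, done) = llama.getNext(); // no per-call SamplingParams anymore
      String? chunk = chatMLFormat.filterResponse(token);
      if (chunk != null) stdout.write(chunk);
      if (done) break;
    }
    llama.clear(); // reset KV cache and counters between turns
  }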

lib/src/llama.dart

Lines changed: 6 additions & 41 deletions

@@ -66,19 +66,19 @@ class Llama {
   String loraBase;
   List<(String, double)> loraAdapters;

-  // late SamplingContext sampling;
-
   /// Constructor for Llama.
   ///
   /// Loads the model and context based on provided model and context parameters.
   Llama(String modelPath,
       [ModelParams? modelParams,
       ContextParams? contextParams,
-      this.samplingParams,
+      SamplingParams? samplingParams,
+      // this.samplingParams,
       this.loraBase = "",
       this.loraAdapters = const []])
       : modelParams = modelParams ?? ModelParams(),
-        contextParams = contextParams ?? ContextParams() {
+        contextParams = contextParams ?? ContextParams(),
+        samplingParams = samplingParams ?? SamplingParams() {
     lib.llama_backend_init(false);
     llama_model_params modelParams = this.modelParams.get();
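Together with the getNext() hunk below, this moves the SamplingParams fallback from a lazy "??=" inside the generation loop into the constructor's initializer list, so the field is settled before the first token is sampled. The pattern in isolation, with illustrative names not taken from the commit:

  // Initializer-list default: resolved once at construction, so later
  // code never needs a null check or a lazy fallback.
  class Sampler {
    final int topK;
    Sampler([int? topK]) : topK = topK ?? 40;
  }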

@@ -118,8 +118,6 @@ class Llama {
       }
     }
     malloc.free(cLoraBase);
-
-    // sampling = SamplingContext(this);
   }

   /// Releases all resources associated with the Llama instance.
@@ -147,8 +145,8 @@
   /// An exception is thrown if the required KV cache size exceeds the context's limit.
   /// The function also initializes the batch for token processing.
   setPrompt(String prompt) {
+    // context = lib.llama_new_context_with_model(model, contextParams.get());
     tokensList = tokenize(prompt, true);
-    // temporaryInvalidCChars = [];

     if (length != -1) {
       int nCtx = lib.llama_n_ctx(context);
@@ -180,8 +178,6 @@
   /// Returns a tuple with the generated text and a boolean indicating if the end-of-sequence token is reached.
   /// An exception is thrown if llama_decode fails during processing.
   (String, bool) getNext() {
-    samplingParams ??= SamplingParams();
-
     Pointer<Int32> newTokenId = calloc.allocate<Int32>(sizeOf<Int32>());
     final nVocab = lib.llama_n_vocab(model);
     final logits = lib.llama_get_logits_ith(context, batch.n_tokens - 1);
@@ -202,33 +198,12 @@
       ..size = nVocab
       ..sorted = true;

-    /*
-    final Pointer<llama_token> nativeLastTokens =
-        malloc.allocate<llama_token>(sizeOf<llama_token>() * lastTokens.length);
-    for (int i = 0; i < lastTokens.length; i++) {
-      nativeLastTokens.elementAt(i).value = i;
-    }
-
-    Pointer<llama_sampling_params> sp = samplingParams!.get();
-    lib.llama_sample_repetition_penalties(
-        context,
-        candidatesP,
-        nativeLastTokens,
-        sp.ref.penalty_last_n,
-        sp.ref.penalty_repeat,
-        sp.ref.penalty_freq,
-        sp.ref.penalty_present);
-    */
-
     SamplingContext sampling = SamplingContext(this);
     sampling.params = samplingParams;

-    // int minKeep = max(1, samplingParams.nProbs);
-    // sampling.tfsZ(candidatesP, minKeep, nVocab);
     newTokenId.value = candidatesP.ref.data.elementAt(0).ref.id;
     newTokenId.value = sampling.sample(newTokenId, null);
-
-    // newTokenId.value = candidatesP.ref.data.elementAt(0).ref.id;
+    sampling.accept(newTokenId.value);

     // newTokenId.value = lib.llama_sample_token_greedy(context, candidatesP);
     // lastTokens.add(newTokenId);
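The surviving path mirrors llama.cpp's common-sampling split: sample() runs the configured chain (seed, top-k, top-p, min-p, temperature, repeat penalties) over the candidate logits, and the newly added accept() pushes the chosen token into the context's history so penalty samplers see it on later calls. Distilled from the hunk above with comments added and candidate setup elided; the note on dispose() is an assumption about what it frees:

  SamplingContext sampling = SamplingContext(this);
  sampling.params = samplingParams;

  newTokenId.value = candidatesP.ref.data.elementAt(0).ref.id; // top-ranked fallback
  newTokenId.value = sampling.sample(newTokenId, null);        // run the sampler chain
  sampling.accept(newTokenId.value);                           // record for penalty history

  sampling.dispose(); // assumed to release the context's native allocations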
@@ -239,13 +214,6 @@

     sampling.dispose();

-    if (newTokenId.value == lib.llama_token_eos(model)) {
-      int token = newTokenId.value;
-      calloc.free(newTokenId);
-      final newTokenStr = tokenToPiece(newTokenId.value);
-      return (newTokenStr, token == lib.llama_token_eos(model));
-    }
-
     final newTokenStr = tokenToPiece(newTokenId.value);

     batch.n_tokens = 0;
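Dropping this early-return also removes a latent use-after-free: the deleted block called calloc.free(newTokenId) and then read newTokenId.value again inside tokenToPiece. Had the fast path stayed, a safe variant would copy the id out before freeing; a sketch, not code from the commit:

  if (newTokenId.value == lib.llama_token_eos(model)) {
    final int token = newTokenId.value; // copy before freeing
    calloc.free(newTokenId);
    return (tokenToPiece(token), true);
  }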
@@ -288,11 +256,8 @@
     lastTokens.clear();
     lib.llama_kv_cache_clear(context);
     batch.n_tokens = 0;
-    tokensList.clear();
-    lastTokens.clear();
     cursor = 0;
     decode = 0;
-    // lib.llama
   }

   // Utility methods

lib/src/sampling_context.dart

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,6 @@ import 'package:ffi/ffi.dart';
 import 'package:llama_cpp_dart/llama_cpp_dart.dart';

 import 'llama_cpp.dart';
-import 'sampling_params.dart';

 class SamplingContext {
   final List<Pointer<llama_token>> _prev = [];
@@ -257,6 +256,7 @@

   accept(int id) {
     if (_prev.isNotEmpty) {
+      calloc.free(_prev[0]);
       _prev.removeAt(0);
     }
     Pointer<llama_token> idx =
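The added line closes a native memory leak: each accept() heap-allocates a llama_token for the history ring, and before this change the evicted head was only removed from the Dart list, leaving its native allocation unreachable (Dart's GC never reclaims calloc'd memory). A hedged reconstruction of the whole method under that reading; everything after the visible fragment is an assumption:

  accept(int id) {
    if (_prev.isNotEmpty) {
      calloc.free(_prev[0]); // the fix: free the evicted allocation first
      _prev.removeAt(0);
    }
    Pointer<llama_token> idx =
        calloc.allocate<llama_token>(sizeOf<llama_token>()); // assumed tail-append
    idx.value = id;
    _prev.add(idx);
  }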
