
Commit 23d60c6

improve sampling
1 parent f444ba2 commit 23d60c6

1 file changed: +40 -18 lines

lib/src/llama.dart

Lines changed: 40 additions & 18 deletions
@@ -178,53 +178,75 @@ class Llama {
   /// Returns a tuple with the generated text and a boolean indicating if the end-of-sequence token is reached.
   /// An exception is thrown if llama_decode fails during processing.
   (String, bool) getNext() {
-    Pointer<Int32> newTokenId = calloc.allocate<Int32>(sizeOf<Int32>());
+    // Allocate memory for the new token ID.
+    Pointer<Int32> newTokenId = calloc<Int32>();
+
+    // Get the number of vocabulary items.
     final int nVocab = lib.llama_n_vocab(model);
+
+    // Get the logits for the last token generated.
     final logits = lib.llama_get_logits_ith(context, batch.n_tokens - 1);
 
+    // Prepare candidates array to hold token data for all vocabulary items.
     final Pointer<llama_token_data> candidates = calloc<llama_token_data>(nVocab);
     for (int tokenId = 0; tokenId < nVocab; tokenId++) {
       candidates[tokenId].id = tokenId;
       candidates[tokenId].logit = logits[tokenId];
-      candidates[tokenId].p = 0.0;
+      candidates[tokenId].p = 0.0; // Initialize probabilities to 0.0.
     }
 
+    // Create a structure to hold the candidates array.
     final Pointer<llama_token_data_array> candidatesP = calloc<llama_token_data_array>();
     candidatesP.ref.data = candidates;
     candidatesP.ref.size = nVocab;
    candidatesP.ref.sorted = false;
 
-    SamplingContext sampling = SamplingContext(this);
-    sampling.params = samplingParams;
-
-    newTokenId.value = candidatesP.ref.data.elementAt(0).ref.id;
-    newTokenId.value = sampling.sample(newTokenId, null);
-    sampling.accept(newTokenId.value);
-
-    // newTokenId.value = lib.llama_sample_token_greedy(context, candidatesP);
-    // lastTokens.add(newTokenId);
+    // Apply sampling strategies (e.g., top-k, top-p, temperature) based on SamplingParams.
+    if (samplingParams != null) {
+      final last_tokens = calloc<Int32>(samplingParams!.nPrev);
+      lib.llama_sample_repetition_penalties(
+          context,
+          candidatesP,
+          last_tokens,
+          samplingParams!.penaltyLastN,
+          samplingParams!.penaltyRepeat,
+          samplingParams!.penaltyFreq,
+          samplingParams!.penaltyPresent
+      );
+      lib.llama_sample_top_k(context, candidatesP, samplingParams!.topK, 1);
+      lib.llama_sample_top_p(context, candidatesP, samplingParams!.topP, 1);
+      lib.llama_sample_temperature(context, candidatesP, samplingParams!.temp);
+    }
 
-    // calloc.free(nativeLastTokens);
-    calloc.free(candidates);
-    calloc.free(candidatesP);
+    // Sample a token from the adjusted logits/probabilities.
+    newTokenId.value = lib.llama_sample_token(context, candidatesP);
 
-    sampling.dispose();
+    // Check if the sampled token is an EOS token.
+    bool isEOSToken = newTokenId.value == lib.llama_token_eos(model);
 
+    // Convert the token ID to its string representation.
     final newTokenStr = tokenToPiece(newTokenId.value);
 
+    // Update the batch and context for the next token generation.
     batch.n_tokens = 0;
     batchAdd(batch, newTokenId.value, cursor, [0], true);
 
+    // Increment the counters.
     decode++;
     cursor++;
 
+    // Process the next token.
     if (lib.llama_decode(context, batch) != 0) {
-      throw Exception("failed to evaluate llama!");
+      throw Exception("Failed to evaluate Llama!");
     }
 
-    int token = newTokenId.value;
+    // Free allocated memory.
     calloc.free(newTokenId);
-    return (newTokenStr, token == lib.llama_token_eos(model));
+    calloc.free(candidates);
+    calloc.free(candidatesP);
+
+    // Return the generated text and whether the EOS token was reached.
+    return (newTokenStr, isEOSToken);
   }
 
   /// Asynchronously generates text based on a given prompt.
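
The new path applies repetition penalties, top-k, top-p, and temperature directly to the candidates array, in the same order llama.cpp's own sampling examples apply them, before drawing the final token with llama_sample_token. Those calls read the fields of samplingParams shown below. The real SamplingParams class is defined elsewhere in this package, so this is only a sketch of the shape the diff implies, with defaults borrowed from llama.cpp's common sampling defaults (assumptions, not the package's actual values):

class SamplingParams {
  // Sketch only: field names are the ones getNext() reads above;
  // the default values are assumed, mirroring llama.cpp's defaults.
  int nPrev = 64;              // size of the recent-token buffer used for penalties
  int penaltyLastN = 64;       // how many recent tokens the penalties consider
  double penaltyRepeat = 1.1;  // repetition penalty (1.0 disables it)
  double penaltyFreq = 0.0;    // frequency penalty
  double penaltyPresent = 0.0; // presence penalty
  int topK = 40;               // keep only the K most likely candidates
  double topP = 0.95;          // nucleus (top-p) sampling threshold
  double temp = 0.8;           // softmax temperature
}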

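Callers consume getNext() one token at a time until the boolean in the returned record flips to true. A minimal driver sketch, assuming a Llama instance whose prompt has already been loaded into the batch; generate and maxTokens are hypothetical names, while getNext() is the method changed in this commit:

String generate(Llama llama, {int maxTokens = 256}) {
  final buffer = StringBuffer();
  for (var i = 0; i < maxTokens; i++) {
    // getNext() returns a (String, bool) record: the decoded text piece
    // and whether the sampled token was the end-of-sequence token.
    final (piece, isEos) = llama.getNext();
    if (isEos) break; // stop once EOS is sampled
    buffer.write(piece);
  }
  return buffer.toString();
}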