@@ -178,53 +178,75 @@ class Llama {
  /// Returns a tuple with the generated text and a boolean indicating if the end-of-sequence token is reached.
  /// An exception is thrown if llama_decode fails during processing.
  (String, bool) getNext() {
-    Pointer<Int32> newTokenId = calloc.allocate<Int32>(sizeOf<Int32>());
+    // Allocate memory for the new token ID.
+    Pointer<Int32> newTokenId = calloc<Int32>();
+
+    // Get the number of vocabulary items.
    final int nVocab = lib.llama_n_vocab(model);
+
+    // Get the logits for the last token generated.
    final logits = lib.llama_get_logits_ith(context, batch.n_tokens - 1);

+    // Prepare candidates array to hold token data for all vocabulary items.
    final Pointer<llama_token_data> candidates = calloc<llama_token_data>(nVocab);
    for (int tokenId = 0; tokenId < nVocab; tokenId++) {
      candidates[tokenId].id = tokenId;
      candidates[tokenId].logit = logits[tokenId];
-      candidates[tokenId].p = 0.0;
+      candidates[tokenId].p = 0.0; // Initialize probabilities to 0.0.
    }

+    // Create a structure to hold the candidates array.
    final Pointer<llama_token_data_array> candidatesP = calloc<llama_token_data_array>();
    candidatesP.ref.data = candidates;
    candidatesP.ref.size = nVocab;
    candidatesP.ref.sorted = false;

-    SamplingContext sampling = SamplingContext(this);
-    sampling.params = samplingParams;
-
-    newTokenId.value = candidatesP.ref.data.elementAt(0).ref.id;
-    newTokenId.value = sampling.sample(newTokenId, null);
-    sampling.accept(newTokenId.value);
-
-    // newTokenId.value = lib.llama_sample_token_greedy(context, candidatesP);
-    // lastTokens.add(newTokenId);
+    // Apply sampling strategies (e.g., top-k, top-p, temperature) based on SamplingParams.
+    if (samplingParams != null) {
+      final last_tokens = calloc<Int32>(samplingParams!.nPrev);
+      lib.llama_sample_repetition_penalties(
+          context,
+          candidatesP,
+          last_tokens,
+          samplingParams!.penaltyLastN,
+          samplingParams!.penaltyRepeat,
+          samplingParams!.penaltyFreq,
+          samplingParams!.penaltyPresent
+      );
+      lib.llama_sample_top_k(context, candidatesP, samplingParams!.topK, 1);
+      lib.llama_sample_top_p(context, candidatesP, samplingParams!.topP, 1);
+      lib.llama_sample_temperature(context, candidatesP, samplingParams!.temp);
+      calloc.free(last_tokens); // Release the temporary buffer of previous tokens to avoid a leak.
+    }

-    // calloc.free(nativeLastTokens);
-    calloc.free(candidates);
-    calloc.free(candidatesP);
+    // Sample a token from the adjusted logits/probabilities.
+    newTokenId.value = lib.llama_sample_token(context, candidatesP);

-    sampling.dispose();
+    // Check if the sampled token is an EOS token.
+    bool isEOSToken = newTokenId.value == lib.llama_token_eos(model);

+    // Convert the token ID to its string representation.
    final newTokenStr = tokenToPiece(newTokenId.value);

+    // Update the batch and context for the next token generation.
    batch.n_tokens = 0;
    batchAdd(batch, newTokenId.value, cursor, [0], true);

+    // Increment the counters.
    decode++;
    cursor++;

+    // Decode the batch containing the new token.
    if (lib.llama_decode(context, batch) != 0) {
-      throw Exception("failed to evaluate llama!");
+      throw Exception("Failed to evaluate Llama!");
    }

-    int token = newTokenId.value;
+    // Free allocated memory.
    calloc.free(newTokenId);
-    return (newTokenStr, token == lib.llama_token_eos(model));
+    calloc.free(candidates);
+    calloc.free(candidatesP);
+
+    // Return the generated text and whether the EOS token was reached.
+    return (newTokenStr, isEOSToken);
  }

  /// Asynchronously generates text based on a given prompt.
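For readers skimming the diff, here is a minimal usage sketch of the updated getNext() in a generation loop. It is an illustration only: it assumes a Llama instance named llama whose prompt has already been tokenized and decoded into the batch (that setup is not part of this hunk), and it simply accumulates pieces until the end-of-sequence flag comes back true.

    // Hypothetical driver loop; "llama" is an assumed, already-initialized instance.
    final output = StringBuffer();
    while (true) {
      final (piece, isEos) = llama.getNext(); // (String, bool) record, as in the diff above
      if (isEos) break;                       // stop once the EOS token is sampled
      output.write(piece);
    }
    print(output.toString());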