Skip to content

Commit 135870c

Browse files
committed
work on sampling
1 parent 23d60c6 commit 135870c

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

lib/src/llama.dart

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import 'dart:convert';
22
import 'dart:ffi';
3+
import 'dart:math';
34

45
import 'package:ffi/ffi.dart';
56
import 'package:llama_cpp_dart/src/sampling_context.dart';
@@ -203,11 +204,14 @@ class Llama {
203204

204205
// Apply sampling strategies (e.g., top-k, top-p, temperature) based on SamplingParams.
205206
if (samplingParams != null) {
206-
final last_tokens = calloc<Int32>(samplingParams!.nPrev);
207+
int minSize = min(samplingParams!.nPrev, lastTokens.length);
208+
Pointer<Int32> lastTokensP = calloc<Int32>(samplingParams!.nPrev);
209+
List<int> safeLastTokens = lastTokens.take(minSize).toList();
210+
lastTokensP.asTypedList(minSize).setAll(0, safeLastTokens);
207211
lib.llama_sample_repetition_penalties(
208212
context,
209213
candidatesP,
210-
last_tokens,
214+
lastTokensP,
211215
samplingParams!.penaltyLastN,
212216
samplingParams!.penaltyRepeat,
213217
samplingParams!.penaltyFreq,
@@ -231,6 +235,9 @@ class Llama {
231235
batch.n_tokens = 0;
232236
batchAdd(batch, newTokenId.value, cursor, [0], true);
233237

238+
// Update the last tokens list.
239+
lastTokens.add(newTokenId.value);
240+
234241
// Increment the counters.
235242
decode++;
236243
cursor++;

0 commit comments

Comments
 (0)