Skip to content

Commit 97c6372

Browse files
committed
Rewind the model state to the longest shared token prefix: instead of requiring the new prompt to exactly extend the previously evaluated tokens, find the longest common prefix between them, discard the evaluated tokens and logits past that point, and re-evaluate only the remainder.
1 parent cabd8b8 commit 97c6372

File tree

1 file changed: +19 additions, -9 deletions

llama_cpp/llama.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -390,18 +390,28 @@ def generate(
390390
"""
391391
assert self.ctx is not None
392392

393-
if (
394-
reset
395-
and len(self.eval_tokens) > 0
396-
and tuple(self.eval_tokens) == tuple(tokens[: len(self.eval_tokens)])
397-
):
398-
if self.verbose:
399-
print("Llama.generate: cache hit", file=sys.stderr)
400-
reset = False
401-
tokens = tokens[len(self.eval_tokens) :]
393+
if reset and len(self.eval_tokens) > 0:
394+
longest_prefix = 0
395+
for a, b in zip(self.eval_tokens, tokens[:-1]):
396+
if a == b:
397+
longest_prefix += 1
398+
else:
399+
break
400+
if longest_prefix > 0:
401+
if self.verbose:
402+
print("Llama.generate: prefix-match hit", file=sys.stderr)
403+
reset = False
404+
tokens = tokens[longest_prefix:]
405+
for _ in range(len(self.eval_tokens) - longest_prefix):
406+
self.eval_tokens.pop()
407+
try:
408+
self.eval_logits.pop()
409+
except IndexError:
410+
pass
402411

403412
if reset:
404413
self.reset()
414+
405415
while True:
406416
self.eval(tokens)
407417
token = self.sample(

0 commit comments

Comments
 (0)