File tree Expand file tree Collapse file tree 1 file changed +19
-9
lines changed Expand file tree Collapse file tree 1 file changed +19
-9
lines changed Original file line number Diff line number Diff line change @@ -390,18 +390,28 @@ def generate(
390
390
"""
391
391
assert self .ctx is not None
392
392
393
- if (
394
- reset
395
- and len (self .eval_tokens ) > 0
396
- and tuple (self .eval_tokens ) == tuple (tokens [: len (self .eval_tokens )])
397
- ):
398
- if self .verbose :
399
- print ("Llama.generate: cache hit" , file = sys .stderr )
400
- reset = False
401
- tokens = tokens [len (self .eval_tokens ) :]
393
+ if reset and len (self .eval_tokens ) > 0 :
394
+ longest_prefix = 0
395
+ for a , b in zip (self .eval_tokens , tokens [:- 1 ]):
396
+ if a == b :
397
+ longest_prefix += 1
398
+ else :
399
+ break
400
+ if longest_prefix > 0 :
401
+ if self .verbose :
402
+ print ("Llama.generate: prefix-match hit" , file = sys .stderr )
403
+ reset = False
404
+ tokens = tokens [longest_prefix :]
405
+ for _ in range (len (self .eval_tokens ) - longest_prefix ):
406
+ self .eval_tokens .pop ()
407
+ try :
408
+ self .eval_logits .pop ()
409
+ except IndexError :
410
+ pass
402
411
403
412
if reset :
404
413
self .reset ()
414
+
405
415
while True :
406
416
self .eval (tokens )
407
417
token = self .sample (
You can’t perform that action at this time.
0 commit comments