@@ -571,7 +571,7 @@ def eval_logits(self) -> Deque[List[float]]:
571
571
)
572
572
573
573
def tokenize (
574
- self , text : bytes , add_bos : bool = True , special : bool = False
574
+ self , vocab : llama_cpp . llama_vocab_p , text : bytes , add_bos : bool = True , special : bool = False
575
575
) -> List [int ]:
576
576
"""Tokenize a string.
577
577
@@ -586,10 +586,11 @@ def tokenize(
586
586
Returns:
587
587
A list of tokens.
588
588
"""
589
- return self .tokenizer_ .tokenize (text , add_bos , special )
589
+ return self .tokenizer_ .tokenize (vocab , text , add_bos , special )
590
590
591
591
def detokenize (
592
592
self ,
593
+ vocab :llama_cpp .llama_vocab_p ,
593
594
tokens : List [int ],
594
595
prev_tokens : Optional [List [int ]] = None ,
595
596
special : bool = False ,
@@ -605,7 +606,7 @@ def detokenize(
605
606
The detokenized string.
606
607
"""
607
608
return self .tokenizer_ .detokenize (
608
- tokens , prev_tokens = prev_tokens , special = special
609
+ vocab , tokens , prev_tokens = prev_tokens , special = special
609
610
)
610
611
611
612
def set_cache (self , cache : Optional [BaseLlamaCache ]):
@@ -1073,7 +1074,7 @@ def decode_batch(seq_sizes: List[int]):
1073
1074
1074
1075
# accumulate batches and encode
1075
1076
for text in inputs :
1076
- tokens = self .tokenize (text .encode ("utf-8" ))
1077
+ tokens = self .tokenize (self . _vocab , text .encode ("utf-8" ))
1077
1078
if truncate :
1078
1079
tokens = tokens [:n_batch ]
1079
1080
@@ -1152,11 +1153,11 @@ def _create_completion(
1152
1153
completion_id : str = f"cmpl-{ str (uuid .uuid4 ())} "
1153
1154
created : int = int (time .time ())
1154
1155
bos_token_id : int = self .token_bos ()
1155
- cls_token_id : int = self ._model .token_cls ()
1156
- sep_token_id : int = self ._model .token_sep ()
1157
- prefix_token_id : int = self ._model .token_prefix ()
1158
- middle_token_id : int = self ._model .token_middle ()
1159
- suffix_token_id : int = self ._model .token_suffix ()
1156
+ cls_token_id : int = self ._model .token_cls (self . _vocab )
1157
+ sep_token_id : int = self ._model .token_sep (self . _vocab )
1158
+ prefix_token_id : int = self ._model .token_prefix (self . _vocab )
1159
+ middle_token_id : int = self ._model .token_middle (self . _vocab )
1160
+ suffix_token_id : int = self ._model .token_suffix (self . _vocab )
1160
1161
add_space_prefix : bool = (
1161
1162
self .metadata .get ("tokenizer.ggml.add_space_prefix" , "true" ) == "true"
1162
1163
)
@@ -1167,13 +1168,13 @@ def _create_completion(
1167
1168
1168
1169
if (
1169
1170
(isinstance (prompt , list ) and suffix is None )
1170
- or not self ._model .add_bos_token ()
1171
+ or not self ._model .add_bos_token (self . _vocab )
1171
1172
or bos_tokens [:1 ] == [- 1 ]
1172
1173
):
1173
1174
bos_tokens = []
1174
1175
1175
1176
if (isinstance (prompt , list ) and suffix is None ) or (
1176
- not self ._model .add_eos_token () and sep_token_id == - 1
1177
+ not self ._model .add_eos_token (self . _vocab ) and sep_token_id == - 1
1177
1178
):
1178
1179
eos_tokens = []
1179
1180
@@ -1192,6 +1193,7 @@ def _create_completion(
1192
1193
) + (
1193
1194
(
1194
1195
self .tokenize (
1196
+ self ._vocab ,
1195
1197
prompt .encode ("utf-8" ),
1196
1198
add_bos = False ,
1197
1199
special = (prefix_token_id < 0 or suffix is None ),
@@ -1206,7 +1208,7 @@ def _create_completion(
1206
1208
(
1207
1209
[suffix_token_id ]
1208
1210
+ (
1209
- self .tokenize (suffix .encode ("utf-8" ), add_bos = False , special = False )[
1211
+ self .tokenize (self . _vocab , suffix .encode ("utf-8" ), add_bos = False , special = False )[
1210
1212
suffix_space_prefix :
1211
1213
]
1212
1214
if suffix
@@ -1334,14 +1336,14 @@ def logit_bias_processor(
1334
1336
logits_processor = logits_processor ,
1335
1337
grammar = grammar ,
1336
1338
):
1337
- if llama_cpp .llama_vocab_is_eog (self ._model . model , token ):
1338
- text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1339
+ if llama_cpp .llama_vocab_is_eog (self ._vocab , token ):
1340
+ text = self .detokenize (self . _vocab , completion_tokens , prev_tokens = prompt_tokens )
1339
1341
finish_reason = "stop"
1340
1342
break
1341
1343
1342
1344
completion_tokens .append (token )
1343
1345
1344
- all_text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1346
+ all_text = self .detokenize (self . _vocab , completion_tokens , prev_tokens = prompt_tokens )
1345
1347
1346
1348
# Contains multi-byte UTF8
1347
1349
for k , char in enumerate (all_text [- 3 :]):
@@ -1366,6 +1368,7 @@ def logit_bias_processor(
1366
1368
if stream :
1367
1369
remaining_tokens = completion_tokens [returned_tokens :]
1368
1370
remaining_text = self .detokenize (
1371
+ self ._vocab ,
1369
1372
remaining_tokens ,
1370
1373
prev_tokens = prompt_tokens + completion_tokens [:returned_tokens ],
1371
1374
)
@@ -1392,6 +1395,7 @@ def logit_bias_processor(
1392
1395
continue
1393
1396
token_end_position += len (
1394
1397
self .detokenize (
1398
+ self ._vocab ,
1395
1399
[token ],
1396
1400
prev_tokens = prompt_tokens
1397
1401
+ completion_tokens [:returned_tokens ],
@@ -1403,12 +1407,14 @@ def logit_bias_processor(
1403
1407
):
1404
1408
break
1405
1409
token_str = self .detokenize (
1410
+ self ._vocab ,
1406
1411
[token ],
1407
1412
prev_tokens = prompt_tokens
1408
1413
+ completion_tokens [:returned_tokens ],
1409
1414
).decode ("utf-8" , errors = "ignore" )
1410
1415
text_offset = len (prompt ) + len (
1411
1416
self .detokenize (
1417
+ self ._vocab ,
1412
1418
completion_tokens [:returned_tokens ],
1413
1419
prev_tokens = prompt_tokens
1414
1420
+ completion_tokens [:returned_tokens ],
@@ -1433,6 +1439,7 @@ def logit_bias_processor(
1433
1439
logprobs_or_none = {
1434
1440
"tokens" : [
1435
1441
self .detokenize (
1442
+ self ._vocab ,
1436
1443
[token ],
1437
1444
prev_tokens = prompt_tokens
1438
1445
+ completion_tokens [:returned_tokens ],
@@ -1451,6 +1458,7 @@ def logit_bias_processor(
1451
1458
"choices" : [
1452
1459
{
1453
1460
"text" : self .detokenize (
1461
+ self ._vocab ,
1454
1462
[token ],
1455
1463
prev_tokens = prompt_tokens
1456
1464
+ completion_tokens [:returned_tokens ],
@@ -1467,6 +1475,7 @@ def logit_bias_processor(
1467
1475
for i in range (1 , len (remaining_tokens ) + 1 ):
1468
1476
try :
1469
1477
bs = self .detokenize (
1478
+ self ._vocab ,
1470
1479
remaining_tokens [:i ],
1471
1480
prev_tokens = prompt_tokens
1472
1481
+ completion_tokens [:returned_tokens ],
@@ -1505,14 +1514,14 @@ def logit_bias_processor(
1505
1514
}
1506
1515
1507
1516
if len (completion_tokens ) >= max_tokens :
1508
- text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1517
+ text = self .detokenize (self . _vocab , completion_tokens , prev_tokens = prompt_tokens )
1509
1518
finish_reason = "length"
1510
1519
break
1511
1520
1512
1521
if stopping_criteria is not None and stopping_criteria (
1513
1522
self ._input_ids , self ._scores [- 1 , :]
1514
1523
):
1515
- text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1524
+ text = self .detokenize (self . _vocab , completion_tokens , prev_tokens = prompt_tokens )
1516
1525
finish_reason = "stop"
1517
1526
1518
1527
if self .verbose :
@@ -1521,6 +1530,7 @@ def logit_bias_processor(
1521
1530
if stream :
1522
1531
remaining_tokens = completion_tokens [returned_tokens :]
1523
1532
remaining_text = self .detokenize (
1533
+ self ._vocab ,
1524
1534
remaining_tokens ,
1525
1535
prev_tokens = prompt_tokens + completion_tokens [:returned_tokens ],
1526
1536
)
@@ -1534,6 +1544,7 @@ def logit_bias_processor(
1534
1544
for token in remaining_tokens :
1535
1545
token_end_position += len (
1536
1546
self .detokenize (
1547
+ self ._vocab ,
1537
1548
[token ],
1538
1549
prev_tokens = prompt_tokens + completion_tokens [:returned_tokens ],
1539
1550
)
@@ -1543,7 +1554,7 @@ def logit_bias_processor(
1543
1554
if logprobs is not None :
1544
1555
if token == bos_token_id :
1545
1556
continue
1546
- token_str = self .detokenize ([token ]).decode (
1557
+ token_str = self .detokenize (self . _vocab , [token ]).decode (
1547
1558
"utf-8" , errors = "ignore"
1548
1559
)
1549
1560
text_offset = len (prompt ) + len (
@@ -1569,15 +1580,15 @@ def logit_bias_processor(
1569
1580
top_logprob .update ({token_str : current_logprobs [int (token )]})
1570
1581
logprobs_or_none = {
1571
1582
"tokens" : [
1572
- self .detokenize ([token ]).decode ("utf-8" , errors = "ignore" )
1583
+ self .detokenize (self . _vocab , [token ]).decode ("utf-8" , errors = "ignore" )
1573
1584
],
1574
1585
"text_offset" : [text_offset ],
1575
1586
"token_logprobs" : [current_logprobs [int (token )]],
1576
1587
"top_logprobs" : [top_logprob ],
1577
1588
}
1578
1589
1579
1590
if token_end_position >= end :
1580
- last_text = self .detokenize ([token ])
1591
+ last_text = self .detokenize (self . _vocab , [token ])
1581
1592
if token_end_position == end - 1 :
1582
1593
break
1583
1594
returned_tokens += 1
0 commit comments