
Commit 12c3bf8

Add more vocab params in file :>
1 parent e4d0d97 commit 12c3bf8

3 files changed: +35 -23 lines changed

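In short, this commit threads an explicit llama_vocab_p handle through the tokenize/detokenize call chain instead of letting each helper reach into model state. A minimal sketch of the new calling convention follows; the model path is a placeholder, llama_model_get_vocab is assumed to be bound as in upstream llama.cpp, and self._model.model / self._vocab are the names the diff below actually uses:

import llama_cpp

llm = llama_cpp.Llama(model_path="model.gguf")  # placeholder path
# Low-level vocab handle; the diff caches the equivalent as self._vocab.
vocab = llama_cpp.llama_model_get_vocab(llm._model.model)

# tokenize/detokenize now take the vocab handle as their first argument:
tokens = llm.tokenize(vocab, b"Hello, world!", add_bos=True, special=False)
text = llm.detokenize(vocab, tokens)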

llama_cpp/_internals.py

Lines changed: 2 additions & 2 deletions
@@ -185,7 +185,7 @@ def detokenize(self, vocab:llama_cpp.llama_vocab_p, tokens: List[int], special:
         # this line removes a leading space if the first token is a beginning of sentence token
         return (
             output[1:]
-            if len(tokens) > 0 and tokens[0] == self.token_bos() and output[0:1] == b" "
+            if len(tokens) > 0 and tokens[0] == self.token_bos(vocab) and output[0:1] == b" "
             else output
         )

@@ -630,7 +630,7 @@ def sample(

         # apply penalties
         if len(self.prev) > 0:
-            nl_token = ctx_main.model.token_nl()
+            nl_token = ctx_main.model.token_nl(vocab)
             nl_logit = logits_array[nl_token]
             last_tokens = self.prev[-self.params.penalty_last_n :]
             last_tokens_size = min(len(last_tokens), self.params.penalty_last_n)
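The internal special-token lookups follow the same pattern: token_bos and token_nl now resolve against a vocab handle passed in by the caller rather than state stored on the wrapper. A short sketch, assuming model is the internal LlamaModel wrapper and vocab a llama_vocab_p obtained as above:

# Sketch: special-token ids are looked up through an explicit vocab handle.
bos_id = model.token_bos(vocab)  # beginning-of-sequence token id
nl_id = model.token_nl(vocab)    # newline token id, used for repetition penalties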

llama_cpp/llama.py

Lines changed: 31 additions & 20 deletions
@@ -571,7 +571,7 @@ def eval_logits(self) -> Deque[List[float]]:
         )

     def tokenize(
-        self, text: bytes, add_bos: bool = True, special: bool = False
+        self, vocab:llama_cpp.llama_vocab_p, text: bytes, add_bos: bool = True, special: bool = False
     ) -> List[int]:
         """Tokenize a string.

@@ -586,10 +586,11 @@ def tokenize(
         Returns:
             A list of tokens.
         """
-        return self.tokenizer_.tokenize(text, add_bos, special)
+        return self.tokenizer_.tokenize(vocab, text, add_bos, special)

     def detokenize(
         self,
+        vocab:llama_cpp.llama_vocab_p,
         tokens: List[int],
         prev_tokens: Optional[List[int]] = None,
         special: bool = False,
@@ -605,7 +606,7 @@ def detokenize(
             The detokenized string.
         """
         return self.tokenizer_.detokenize(
-            tokens, prev_tokens=prev_tokens, special=special
+            vocab, tokens, prev_tokens=prev_tokens, special=special
         )

     def set_cache(self, cache: Optional[BaseLlamaCache]):
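For users of the high-level Llama class, the visible change is the new leading vocab parameter on tokenize and detokenize. A usage sketch, assuming llm is a Llama instance and passing its cached self._vocab handle (introduced elsewhere in this change set) directly:

# Round-trip through the updated public API; bytes in, bytes out.
ids = llm.tokenize(llm._vocab, b"The quick brown fox", add_bos=True)
text = llm.detokenize(llm._vocab, ids)
assert isinstance(text, bytes)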
@@ -1073,7 +1074,7 @@ def decode_batch(seq_sizes: List[int]):

         # accumulate batches and encode
         for text in inputs:
-            tokens = self.tokenize(text.encode("utf-8"))
+            tokens = self.tokenize(self._vocab, text.encode("utf-8"))
             if truncate:
                 tokens = tokens[:n_batch]

@@ -1152,11 +1153,11 @@ def _create_completion(
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
         bos_token_id: int = self.token_bos()
-        cls_token_id: int = self._model.token_cls()
-        sep_token_id: int = self._model.token_sep()
-        prefix_token_id: int = self._model.token_prefix()
-        middle_token_id: int = self._model.token_middle()
-        suffix_token_id: int = self._model.token_suffix()
+        cls_token_id: int = self._model.token_cls(self._vocab)
+        sep_token_id: int = self._model.token_sep(self._vocab)
+        prefix_token_id: int = self._model.token_prefix(self._vocab)
+        middle_token_id: int = self._model.token_middle(self._vocab)
+        suffix_token_id: int = self._model.token_suffix(self._vocab)
         add_space_prefix: bool = (
             self.metadata.get("tokenizer.ggml.add_space_prefix", "true") == "true"
         )
@@ -1167,13 +1168,13 @@ def _create_completion(

         if (
             (isinstance(prompt, list) and suffix is None)
-            or not self._model.add_bos_token()
+            or not self._model.add_bos_token(self._vocab)
             or bos_tokens[:1] == [-1]
         ):
             bos_tokens = []

         if (isinstance(prompt, list) and suffix is None) or (
-            not self._model.add_eos_token() and sep_token_id == -1
+            not self._model.add_eos_token(self._vocab) and sep_token_id == -1
         ):
             eos_tokens = []

@@ -1192,6 +1193,7 @@ def _create_completion(
         ) + (
             (
                 self.tokenize(
+                    self._vocab,
                     prompt.encode("utf-8"),
                     add_bos=False,
                     special=(prefix_token_id < 0 or suffix is None),
@@ -1206,7 +1208,7 @@ def _create_completion(
             (
                 [suffix_token_id]
                 + (
-                    self.tokenize(suffix.encode("utf-8"), add_bos=False, special=False)[
+                    self.tokenize(self._vocab, suffix.encode("utf-8"), add_bos=False, special=False)[
                         suffix_space_prefix:
                     ]
                     if suffix
@@ -1334,14 +1336,14 @@ def logit_bias_processor(
             logits_processor=logits_processor,
             grammar=grammar,
         ):
-            if llama_cpp.llama_vocab_is_eog(self._model.model, token):
-                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            if llama_cpp.llama_vocab_is_eog(self._vocab, token):
+                text = self.detokenize(self._vocab, completion_tokens, prev_tokens=prompt_tokens)
                 finish_reason = "stop"
                 break

             completion_tokens.append(token)

-            all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            all_text = self.detokenize(self._vocab, completion_tokens, prev_tokens=prompt_tokens)

             # Contains multi-byte UTF8
             for k, char in enumerate(all_text[-3:]):
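Note the end-of-generation check in the hunk above: llama_vocab_is_eog now takes the vocab handle directly instead of the raw model pointer. The check in isolation, sketched with the same names:

# Stop condition: ask the vocab whether this token ends generation
# (EOS/EOT-style ids count as end-of-generation in llama.cpp).
if llama_cpp.llama_vocab_is_eog(llm._vocab, token):
    finish_reason = "stop"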
@@ -1366,6 +1368,7 @@ def logit_bias_processor(
             if stream:
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_text = self.detokenize(
+                    self._vocab,
                     remaining_tokens,
                     prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
                 )
@@ -1392,6 +1395,7 @@ def logit_bias_processor(
                         continue
                     token_end_position += len(
                         self.detokenize(
+                            self._vocab,
                             [token],
                             prev_tokens=prompt_tokens
                             + completion_tokens[:returned_tokens],
@@ -1403,12 +1407,14 @@ def logit_bias_processor(
                     ):
                         break
                     token_str = self.detokenize(
+                        self._vocab,
                         [token],
                         prev_tokens=prompt_tokens
                         + completion_tokens[:returned_tokens],
                     ).decode("utf-8", errors="ignore")
                     text_offset = len(prompt) + len(
                         self.detokenize(
+                            self._vocab,
                             completion_tokens[:returned_tokens],
                             prev_tokens=prompt_tokens
                             + completion_tokens[:returned_tokens],
@@ -1433,6 +1439,7 @@ def logit_bias_processor(
                     logprobs_or_none = {
                         "tokens": [
                             self.detokenize(
+                                self._vocab,
                                 [token],
                                 prev_tokens=prompt_tokens
                                 + completion_tokens[:returned_tokens],
@@ -1451,6 +1458,7 @@ def logit_bias_processor(
                         "choices": [
                             {
                                 "text": self.detokenize(
+                                    self._vocab,
                                     [token],
                                     prev_tokens=prompt_tokens
                                     + completion_tokens[:returned_tokens],
@@ -1467,6 +1475,7 @@ def logit_bias_processor(
                 for i in range(1, len(remaining_tokens) + 1):
                     try:
                         bs = self.detokenize(
+                            self._vocab,
                             remaining_tokens[:i],
                             prev_tokens=prompt_tokens
                             + completion_tokens[:returned_tokens],
@@ -1505,14 +1514,14 @@ def logit_bias_processor(
                     }

             if len(completion_tokens) >= max_tokens:
-                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+                text = self.detokenize(self._vocab, completion_tokens, prev_tokens=prompt_tokens)
                 finish_reason = "length"
                 break

         if stopping_criteria is not None and stopping_criteria(
             self._input_ids, self._scores[-1, :]
         ):
-            text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            text = self.detokenize(self._vocab, completion_tokens, prev_tokens=prompt_tokens)
             finish_reason = "stop"

         if self.verbose:
@@ -1521,6 +1530,7 @@ def logit_bias_processor(
         if stream:
             remaining_tokens = completion_tokens[returned_tokens:]
             remaining_text = self.detokenize(
+                self._vocab,
                 remaining_tokens,
                 prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
             )
@@ -1534,6 +1544,7 @@ def logit_bias_processor(
             for token in remaining_tokens:
                 token_end_position += len(
                     self.detokenize(
+                        self._vocab,
                         [token],
                         prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
                     )
@@ -1543,7 +1554,7 @@ def logit_bias_processor(
                 if logprobs is not None:
                     if token == bos_token_id:
                         continue
-                    token_str = self.detokenize([token]).decode(
+                    token_str = self.detokenize(self._vocab, [token]).decode(
                         "utf-8", errors="ignore"
                     )
                     text_offset = len(prompt) + len(
@@ -1569,15 +1580,15 @@ def logit_bias_processor(
                     top_logprob.update({token_str: current_logprobs[int(token)]})
                     logprobs_or_none = {
                         "tokens": [
-                            self.detokenize([token]).decode("utf-8", errors="ignore")
+                            self.detokenize(self._vocab, [token]).decode("utf-8", errors="ignore")
                         ],
                         "text_offset": [text_offset],
                         "token_logprobs": [current_logprobs[int(token)]],
                         "top_logprobs": [top_logprob],
                     }

                 if token_end_position >= end:
-                    last_text = self.detokenize([token])
+                    last_text = self.detokenize(self._vocab, [token])
                     if token_end_position == end - 1:
                         break
                     returned_tokens += 1

llama_cpp/llama_tokenizer.py

Lines changed: 2 additions & 1 deletion
@@ -81,14 +81,15 @@ def __init__(self, hf_tokenizer: Any):
         self.hf_tokenizer = hf_tokenizer

     def tokenize(
-        self, text: bytes, add_bos: bool = True, special: bool = True
+        self, vocab:llama_cpp.llama_vocab_p, text: bytes, add_bos: bool = True, special: bool = True
     ) -> List[int]:
         return self.hf_tokenizer.encode(
             text.decode("utf-8", errors="ignore"), add_special_tokens=special
         )

     def detokenize(
         self,
+        vocab:llama_cpp.llama_vocab_p,
         tokens: List[int],
         prev_tokens: Optional[List[int]] = None,
         special: bool = False,
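As the hunk shows, the HF-backed tokenizer accepts the new vocab argument purely for interface parity: encoding still goes through hf_tokenizer.encode and the handle is never consulted. A sketch, assuming the transformers package and an arbitrary pretrained tokenizer id:

from transformers import AutoTokenizer

hf = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer works here
tok = LlamaHFTokenizer(hf)
# vocab is ignored by the HF path; None suffices in this sketch.
ids = tok.tokenize(None, b"hello world", special=True)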
