@@ -6,16 +6,22 @@
 
 
 import argparse
-from typing import Optional
+
+from typing import Optional, Union
 
 import lm_eval
 import torch
 
+from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
+from executorch.examples.models.llama2.tokenizer.tokenizer import (
+    Tokenizer as SentencePieceTokenizer,
+)
+
 from lm_eval.api.model import LM
 from lm_eval.evaluator import evaluate
 from lm_eval.models.huggingface import HFLM as eval_wrapper
 from lm_eval.tasks import get_task_dict
-from sentencepiece import SentencePieceProcessor
+
 from torch import nn
 
 from .builder import LlamaEdgeManager
@@ -33,7 +39,7 @@ class GPTFastEvalWrapper(eval_wrapper):
     def __init__(
         self,
         model: nn.Module,
-        tokenizer: SentencePieceProcessor,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
     ):
         super().__init__()
@@ -46,7 +52,7 @@ def __init__(
 
     @property
     def eot_token_id(self):
-        return self._tokenizer.eos_id()
+        return self._tokenizer.eos_id
 
     @property
     def max_length(self):
@@ -65,7 +71,7 @@ def device(self):
         return self._device
 
     def tok_encode(self, string: str, **kwargs):
-        tokens = [self._tokenizer.bos_id()] + self._tokenizer.encode(string)
+        tokens = self._tokenizer.encode(string, bos=True, eos=False)
         encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
         # encoded is a pytorch tensor, but some internal logic in the
         # eval harness expects it to be a list instead
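
After this change the wrapper relies on both tokenizer classes sharing the same small surface: eos_id is a plain attribute rather than SentencePieceProcessor's eos_id() method (hence the dropped parentheses in eot_token_id above), and encode() takes explicit bos/eos flags so callers no longer prepend bos_id() by hand. A minimal sketch of that assumed shared interface, inferred from this diff rather than from the tokenizer sources:

    from typing import List, Protocol

    class EvalTokenizer(Protocol):
        # On both wrappers, eos_id is an attribute, not a method.
        eos_id: int

        # encode() inserts the BOS/EOS tokens itself when asked to.
        def encode(self, s: str, bos: bool, eos: bool) -> List[int]: ...
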
@@ -93,7 +99,7 @@ class ETEagerEvalWrapper(GPTFastEvalWrapper):
     def __init__(
         self,
         model: str,
-        tokenizer: SentencePieceProcessor,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
     ):
         super().__init__(None, tokenizer, max_seq_length)
@@ -120,7 +126,7 @@ class ETRunnerEvalWrapper(GPTFastEvalWrapper):
     def __init__(
         self,
         model: str,
-        tokenizer: SentencePieceProcessor,
+        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         tokenizer_bin: str,
         max_seq_length: Optional[int] = None,
     ):
@@ -183,7 +189,11 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
+    try:
+        tokenizer = SentencePieceTokenizer(model_path=str(args.tokenizer_path))
+    except Exception:
+        print("Using Tiktokenizer")
+        tokenizer = Tiktoken(model_path=str(args.tokenizer_path))
 
     # ExecuTorch Binary Evaluation
     if (model := args.pte) is not None:
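
The try/except above selects the tokenizer by probing the file: a SentencePiece model loads cleanly, while anything else (such as a Tiktoken vocabulary, which is presumably not a valid SentencePiece model proto) raises during construction and triggers the fallback. The same pattern as a standalone helper, using the hypothetical name load_tokenizer:

    def load_tokenizer(tokenizer_path: str):
        # Hypothetical helper mirroring the fallback in gen_eval_wrapper:
        # try SentencePiece first, fall back to Tiktoken if loading fails.
        try:
            return SentencePieceTokenizer(model_path=tokenizer_path)
        except Exception:
            print("Using Tiktokenizer")
            return Tiktoken(model_path=tokenizer_path)

Catching the broad Exception keeps the probe simple, at the cost of masking genuinely corrupt SentencePiece files, which would silently fall through and fail inside Tiktoken instead.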