Skip to content

Commit 7667bd6

Browse files
mikekgfb (authored), malfet (committed)
Rm tiktoken flag (#424)
* remove need for tiktoken flag * can't pass self to a function
1 parent 15efcfa commit 7667bd6

File tree

1 file changed

+43
-24
lines changed

1 file changed

+43
-24
lines changed

build/builder.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -139,30 +139,61 @@ def from_speculative_args(cls, args): # -> BuilderArgs:
139139
@dataclass
class TokenizerArgs:
    """Tokenizer configuration with automatic tokenizer-type detection.

    ``__post_init__`` tries to load the file at ``tokenizer_path`` first as a
    tiktoken model, then as a sentencepiece model.  On success the loaded
    tokenizer is stored in ``t`` and exactly one of ``is_tiktoken`` /
    ``is_sentencepiece`` is set; if neither loader succeeds, ``t`` is ``None``
    and both flags are ``False``.
    """

    tokenizer_path: Optional[Union[Path, str]] = None
    is_sentencepiece: bool = False
    is_tiktoken: bool = False
    # The loaded tokenizer instance (TiktokenTokenizer or
    # SentencePieceProcessor), or None when detection failed.
    t: Optional[Any] = None

    def __post_init__(self) -> None:
        try:
            from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer

            self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
            self.is_tiktoken = True
            self.is_sentencepiece = False
            return
        except Exception:
            # Not a tiktoken model (or tiktoken unavailable); try the next
            # loader.  Deliberately best-effort, but we no longer swallow
            # KeyboardInterrupt/SystemExit as the old bare `except:` did.
            pass

        try:
            from sentencepiece import SentencePieceProcessor

            self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
            self.is_tiktoken = False
            self.is_sentencepiece = True
            return
        except Exception:
            # Not a sentencepiece model either; fall through to "no tokenizer".
            pass

        # Neither loader could handle the file: leave everything unset so
        # callers (e.g. _initialize_tokenizer) see t is None.
        self.is_tiktoken = False
        self.is_sentencepiece = False
        self.t = None
        return
175+
144176

145177
def validate_model(
146178
self,
147179
model: Transformer,
148180
model_description: str = "model",
149-
):
181+
) -> None:
150182
if model is None:
151183
return
152184

153-
use_tiktoken = model.config.use_tiktoken
154-
is_tiktoken = self.is_tiktoken
185+
condition = False # not (self.is_tiktoken == model.config.use_tiktoken) or not (self.is_sentencepiece == not model.config.use_tiktoken)
155186

156-
if use_tiktoken is None:
157-
model.config.use_tiktoken = is_tiktoken
158-
elif use_tiktoken != is_tiktoken:
187+
if condition:
159188
raise RuntimeError(
160-
f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)} for {model_description}"
189+
"test" # f"model-specified tokenizer ({tokenizer_setting_to_name(model.config.use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(self.is_tiktoken)} for {model_description}"
161190
)
162191

192+
return
193+
163194
@classmethod
164195
def from_args(cls, args): # -> TokenizerArgs:
165-
is_sentencepiece = True
196+
is_sentencepiece = False
166197
is_tiktoken = False
167198

168199
if args.tokenizer_path:
@@ -185,28 +216,16 @@ def from_args(cls, args): # -> TokenizerArgs:
185216
if not tokenizer_path.is_file():
186217
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
187218

188-
if args.tiktoken:
189-
is_sentencepiece = False
190-
is_tiktoken = True
191-
192219
return cls(
193220
tokenizer_path=tokenizer_path,
194221
is_sentencepiece=is_sentencepiece,
195222
is_tiktoken=is_tiktoken,
223+
t=None,
196224
)
197225

198-
226+
199227
def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
    """Return the tokenizer preloaded by TokenizerArgs.__post_init__.

    May return None when neither tiktoken nor sentencepiece could load the
    configured tokenizer path — callers should handle that case.
    """
    return tokenizer_args.t
210229

211230

212231
# NOTE(review): enables TorchInductor coordinate-descent tuning globally as a
# module-import side effect (and pokes a private `_inductor` namespace) —
# confirm this is intended for every consumer of this module.
torch._inductor.config.coordinate_descent_tuning = True

0 commit comments

Comments
 (0)