Skip to content

Commit d42c5c5

Browse files
mikekgfb authored and malfet committed
Tokenizer validation (#293)
* arg handling
* phase ordering issue resolved
* add args validation to eval, fix imports
* pass model, not model_args to validate_args
* read before write
* fixes
1 parent 3e50c42 commit d42c5c5

File tree

4 files changed

+30
-11
lines changed

4 files changed

+30
-11
lines changed

build/builder.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,13 @@ def from_speculative_args(cls, args): # -> BuilderArgs:
124124
@dataclass
125125
class TokenizerArgs:
126126
tokenizer_path: Optional[Union[Path, str]] = None
127-
is_SentencePiece: bool = True
128-
is_TikToken: bool = False
127+
is_sentencepiece: bool = True
128+
is_tiktoken: bool = False
129129

130130
@classmethod
131131
def from_args(cls, args): # -> TokenizerArgs:
132-
is_SentencePiece = True
133-
is_TikToken = False
132+
is_sentencepiece = True
133+
is_tiktoken = False
134134

135135
if args.tokenizer_path:
136136
tokenizer_path = args.tokenizer_path
@@ -145,20 +145,20 @@ def from_args(cls, args): # -> TokenizerArgs:
145145
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
146146

147147
if args.tiktoken:
148-
is_SentencePiece = False
149-
is_TikToken = True
148+
is_sentencepiece = False
149+
is_tiktoken = True
150150

151151
return cls(
152152
tokenizer_path=tokenizer_path,
153-
is_SentencePiece=is_SentencePiece,
154-
is_TikToken=is_TikToken,
153+
is_sentencepiece=is_sentencepiece,
154+
is_tiktoken=is_tiktoken,
155155
)
156156

157157

158158
def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
159-
if tokenizer_args.is_SentencePiece:
159+
if tokenizer_args.is_sentencepiece:
160160
return SentencePieceProcessor(model_file=str(tokenizer_args.tokenizer_path))
161-
elif tokenizer_args.is_TikToken:
161+
elif tokenizer_args.is_tiktoken:
162162
return TiktokenTokenizer(model_path=str(tokenizer_args.tokenizer_path))
163163
else:
164164
raise RuntimeError("must specify a valid tokenizer in TokenizerArgs")
@@ -347,3 +347,13 @@ def _initialize_model(
347347
model.to(dtype=builder_args.precision)
348348

349349
return model
350+
351+
def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
    """Return the human-readable tokenizer-family name for a tiktoken flag."""
    return "TikToken" if tiktoken else "SentencePiece"


def validate_args(model: "Transformer", tokenizer_args: "TokenizerArgs") -> None:
    """Raise if the tokenizer the model expects differs from the one provided.

    Args:
        model: loaded model; ``model.config.use_tiktoken`` records which
            tokenizer family the checkpoint was built for.
        tokenizer_args: tokenizer selection derived from the CLI arguments
            (``is_tiktoken`` flag).

    Raises:
        RuntimeError: when the model-specified tokenizer family does not
            match the tokenizer the user supplied.
    """
    use_tiktoken = model.config.use_tiktoken
    is_tiktoken = tokenizer_args.is_tiktoken
    if use_tiktoken != is_tiktoken:
        # Bug fix: the original message left both "(...)" groups unclosed,
        # producing an unbalanced, confusing error string.
        raise RuntimeError(
            f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) "
            f"does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)})"
        )

build/model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ class ModelArgs:
3737
norm_eps: float = 1e-5
3838
multiple_of: int = 256
3939
ffn_dim_multiplier: Optional[int] = None
40-
40+
use_tiktoken: bool = False
41+
4142
def __post_init__(self):
4243
if self.n_local_heads == -1:
4344
self.n_local_heads = self.n_heads
@@ -52,7 +53,10 @@ def __post_init__(self):
5253
hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
5354
self.hidden_dim = find_multiple(hidden_dim, multiple_of)
5455
self.head_dim = self.dim // self.n_heads
56+
if isinstance(self.use_tiktoken, str):
57+
self.use_tiktoken = (self.use_tiktoken == "True")
5558

59+
5660
@classmethod
5761
def from_params(cls, params_path):
5862
replace = [("rope_theta", "rope_base"), ("n_kv_heads", "n_local_heads")]
@@ -135,6 +139,7 @@ def from_name(cls, name: str):
135139
"n_layers": 32,
136140
"rope_base": 500000.0, # rope_theta
137141
"vocab_size": 128256,
142+
"use_tiktoken": True,
138143
},
139144
"Mistral-7B": {
140145
"n_layers": 32,

eval.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from build.builder import (
1515
_initialize_model,
1616
_initialize_tokenizer,
17+
validate_args,
1718
BuilderArgs,
1819
TokenizerArgs,
1920
)
@@ -238,6 +239,7 @@ def main(args) -> None:
238239
builder_args,
239240
quantize,
240241
)
242+
validate_args(model, tokenizer_args)
241243

242244
if compile:
243245
assert not (

generate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_initialize_model,
2222
_initialize_tokenizer,
2323
_load_model,
24+
validate_args,
2425
BuilderArgs,
2526
TokenizerArgs,
2627
)
@@ -396,6 +397,7 @@ def _main(
396397

397398
builder_args.setup_caches = False
398399
model = _initialize_model(builder_args, quantize)
400+
validate_args(model, tokenizer_args)
399401

400402
# will add a version of _initialize_model in future
401403
# (need additional args)

0 commit comments

Comments (0)