Skip to content

Commit 1416fd0

Browse files
committed
Add _set_vocab_rwkv_world as a common function
Signed-off-by: Molly Sophia <[email protected]>
1 parent ba47558 commit 1416fd0

File tree

1 file changed

+39
-33
lines changed

1 file changed

+39
-33
lines changed

convert_hf_to_gguf.py

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,40 @@ def _set_vocab_llama_hf(self):
902902
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
903903
special_vocab.add_to_gguf(self.gguf_writer)
904904

905+
def _set_vocab_rwkv_world(self):
    """Build the GGUF vocabulary from an RWKV "world" tokenizer file.

    Expects ``rwkv_vocab_v20230424.txt`` in the model directory. Each line has
    the form ``<id> <token-literal> <byte-length>`` where the token literal is
    a Python ``str``/``bytes`` literal. Index 0 is reserved for the ``<s>``
    control token, and the list is padded with ``[PADn]`` placeholders up to
    ``vocab_size`` (default 65536 when absent from hparams).
    """
    vocab_file = self.dir_model / "rwkv_vocab_v20230424.txt"
    assert vocab_file.is_file()
    vocab_size = self.hparams.get("vocab_size", 65536)

    tokens: list[bytes] = ['<s>'.encode("utf-8")]
    toktypes: list[int] = [gguf.TokenType.CONTROL]

    with open(vocab_file, "r", encoding="utf-8") as f:
        # iterate the file lazily instead of materializing readlines()
        for line in f:
            parts = line.split(' ')
            assert len(parts) >= 3
            # the middle fields form one Python literal (it may itself contain
            # spaces); the last field is the literal's byte length
            token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
            token = token.encode("utf-8") if isinstance(token, str) else token
            assert isinstance(token, bytes)
            assert len(token) == token_len
            token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
            tokens.append(token_text.encode("utf-8"))
            toktypes.append(gguf.TokenType.NORMAL)

    # pad up to the declared vocab size with unused placeholder tokens
    assert vocab_size >= len(tokens)
    for i in range(len(tokens), vocab_size):
        tokens.append(f"[PAD{i}]".encode("utf-8"))
        toktypes.append(gguf.TokenType.UNUSED)

    self.gguf_writer.add_tokenizer_model("rwkv")
    self.gguf_writer.add_token_list(tokens)
    self.gguf_writer.add_token_types(toktypes)
    special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
    special_vocab.chat_template = "rwkv-world"
    # hack: Add '\n\n' as the EOT token to make it chat normally
    special_vocab._set_special_token("eot", 261)
    special_vocab.add_to_gguf(self.gguf_writer)
938+
905939
def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
906940
tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
907941
logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
@@ -3327,38 +3361,7 @@ class Rwkv6Model(Model):
33273361
model_arch = gguf.MODEL_ARCH.RWKV6
33283362

33293363
def set_vocab(self):
    """Populate the vocab via the shared RWKV world-tokenizer helper."""
    self._set_vocab_rwkv_world()
33623365

33633366
def set_gguf_parameters(self):
33643367
block_count = self.hparams["num_hidden_layers"]
@@ -3481,9 +3484,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
34813484

34823485

34833486
@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
3484-
class Rwkv7Model(Rwkv6Model):
3487+
class Rwkv7Model(Model):
34853488
model_arch = gguf.MODEL_ARCH.RWKV7
34863489

3490+
def set_vocab(self):
    """Populate the vocab via the shared RWKV world-tokenizer helper."""
    self._set_vocab_rwkv_world()
3492+
34873493
def calc_lora_rank(self, hidden_size, exponent, multiplier):
    """Return a LoRA rank: hidden_size ** exponent * multiplier, divided by
    32, rounded, clamped to at least 1, then scaled back to a multiple of 32."""
    scaled = (hidden_size ** exponent) * multiplier / 32
    return 32 * max(1, round(scaled))
34893495

0 commit comments

Comments
 (0)