Skip to content

Commit 20a1a4e

Browse files
R2D2FISH and ggerganov authored
Fix GPTQ converter (ggml-org#423)
* Fix GPTQ converter
* Fix comment

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent ad072fc commit 20a1a4e

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

convert-gptq-to-ggml.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,8 @@
3636

3737
fout = open(fname_out, "wb")
3838

39-
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
39+
fout.write(struct.pack("i", 0x67676d66)) # magic: ggmf in hex
40+
fout.write(struct.pack("i", 1)) # file version
4041
fout.write(struct.pack("i", n_vocab))
4142
fout.write(struct.pack("i", n_embd))
4243
fout.write(struct.pack("i", n_mult))
@@ -49,27 +50,21 @@
4950
# This loop unchanged from convert-pth-to-ggml.py:
5051
for i in range(tokenizer.vocab_size()):
5152
if tokenizer.is_unknown(i):
52-
# "<unk>" token (translated as ??)
5353
text = " \u2047 ".encode("utf-8")
54-
fout.write(struct.pack("i", len(text)))
55-
fout.write(text)
5654
elif tokenizer.is_control(i):
57-
# "<s>"/"</s>" tokens
58-
fout.write(struct.pack("i", 0))
55+
text = b""
5956
elif tokenizer.is_byte(i):
60-
# "<U+XX>" tokens (which may be invalid UTF-8)
6157
piece = tokenizer.id_to_piece(i)
6258
if len(piece) != 6:
63-
print("Invalid token: " + piece)
59+
print(f"Invalid token: {piece}")
6460
sys.exit(1)
6561
byte_value = int(piece[3:-1], 16)
66-
fout.write(struct.pack("i", 1))
67-
fout.write(struct.pack("B", byte_value))
62+
text = struct.pack("B", byte_value)
6863
else:
69-
# normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
7064
text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
71-
fout.write(struct.pack("i", len(text)))
72-
fout.write(text)
65+
fout.write(struct.pack("i", len(text)))
66+
fout.write(text)
67+
fout.write(struct.pack("f", tokenizer.get_score(i)))
7368

7469
def write_header(shape, dst_name, ftype_cur):
7570
sname = dst_name.encode('utf-8')

0 commit comments

Comments
 (0)