
Commit efab7f8

blacked gptq-to-ggml
1 parent eec105a commit efab7f8

File tree

1 file changed (+48, -21 lines)


convert-gptq-to-ggml.py

Lines changed: 48 additions & 21 deletions
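
The commit title says the file was run through black, and the added lines below match black's default style: double quotes, 88-column wrapping with trailing commas, two blank lines around top-level function definitions, and two spaces before inline comments. As a hedged sketch (not part of the commit), a reformat like this can usually be reproduced with black's Python API; the CLI equivalent would be running black on convert-gptq-to-ggml.py.

from pathlib import Path

import black  # assumes the black package is installed

path = Path("convert-gptq-to-ggml.py")
source = path.read_text()

# black.Mode() selects the defaults (88-column target, double quotes),
# which is the style visible in the added lines of this diff.
formatted = black.format_str(source, mode=black.Mode())

if formatted != source:
    path.write_text(formatted)
    print("reformatted:", path)
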
@@ -20,9 +20,12 @@
 
 model = torch.load(fname_model, map_location="cpu")
 
-n_vocab, n_embd = model['model.embed_tokens.weight'].shape
-n_layer = 1 + max(int(m.group(1)) for name in model
-                  if (m := re.match(r'model\.layers\.([0-9]+)', name)))
+n_vocab, n_embd = model["model.embed_tokens.weight"].shape
+n_layer = 1 + max(
+    int(m.group(1))
+    for name in model
+    if (m := re.match(r"model\.layers\.([0-9]+)", name))
+)
 
 # hardcoded:
 n_mult = 256
@@ -36,14 +39,14 @@
 
 fout = open(fname_out, "wb")
 
-fout.write(struct.pack("i", 0x67676d66)) # magic: ggmf in hex
-fout.write(struct.pack("i", 1)) # file version
+fout.write(struct.pack("i", 0x67676D66))  # magic: ggmf in hex
+fout.write(struct.pack("i", 1))  # file version
 fout.write(struct.pack("i", n_vocab))
 fout.write(struct.pack("i", n_embd))
 fout.write(struct.pack("i", n_mult))
 fout.write(struct.pack("i", n_head))
 fout.write(struct.pack("i", n_layer))
-fout.write(struct.pack("i", n_embd // n_head)) # rot (obsolete)
+fout.write(struct.pack("i", n_embd // n_head))  # rot (obsolete)
 fout.write(struct.pack("i", 4))
 
 
@@ -66,16 +69,23 @@
     fout.write(text)
     fout.write(struct.pack("f", tokenizer.get_score(i)))
 
+
 def write_header(shape, dst_name, ftype_cur):
-    sname = dst_name.encode('utf-8')
+    sname = dst_name.encode("utf-8")
     fout.write(struct.pack("iii", len(shape), len(sname), ftype_cur))
     fout.write(struct.pack("i" * len(shape), *shape[::-1]))
     fout.write(sname)
 
+
 def convert_non_q4(src_name, dst_name):
     v = model[src_name]
     shape = v.shape
-    print("Processing non-Q4 variable: " + src_name + " with shape: ", shape, " and type: ", v.dtype)
+    print(
+        "Processing non-Q4 variable: " + src_name + " with shape: ",
+        shape,
+        " and type: ",
+        v.dtype,
+    )
     if len(shape) == 1:
         print(" Converting to float32")
         v = v.to(torch.float32)
@@ -88,11 +98,12 @@ def convert_non_q4(src_name, dst_name):
     # data
     v.numpy().tofile(fout)
 
+
 def convert_q4(src_name, dst_name, permute=False):
     zeros = model[f"{src_name}.zeros"].numpy()
     scales = model[f"{src_name}.scales"].numpy()
     bias = model[f"{src_name}.bias"].numpy()
-    qweight = model[f"{src_name}.qweight"].numpy().T # transpose
+    qweight = model[f"{src_name}.qweight"].numpy().T  # transpose
 
     # Q4_1 does not support bias; good thing the bias is always all zeros.
     assert not np.any(bias)
@@ -113,7 +124,7 @@ def convert_q4(src_name, dst_name, permute=False):
     # the columns in a row, so we end up wasting quite a bit of memory with
     # repeated scales and addends.
 
-    addends = -zeros # flip sign
+    addends = -zeros  # flip sign
 
     # Since the output format is mixed between integers and floats, we have
     # to hackily view the floats as int32s just so numpy will let us
@@ -128,37 +139,53 @@ def convert_q4(src_name, dst_name, permute=False):
     addends_rep = np.atleast_3d(addends_view).repeat(grouped.shape[1], axis=1)
     scales_rep = np.atleast_3d(scales_view).repeat(grouped.shape[1], axis=1)
 
-    blob = np.concatenate([scales_rep, addends_rep, grouped], axis=2, casting='no')
+    blob = np.concatenate([scales_rep, addends_rep, grouped], axis=2, casting="no")
 
     if permute:
         # Permute some rows to undo the permutation done by convert_llama_weights_to_hf.py.
         # This can be done after the above conversion because it doesn't affect column order/layout.
-        blob = (blob.reshape(n_head, 2, shape[0] // n_head // 2, *blob.shape[1:])
-                    .swapaxes(1, 2)
-                    .reshape(blob.shape))
+        blob = (
+            blob.reshape(n_head, 2, shape[0] // n_head // 2, *blob.shape[1:])
+            .swapaxes(1, 2)
+            .reshape(blob.shape)
+        )
 
     # header
-    write_header(shape, dst_name, 3) # ftype = Q4_1
+    write_header(shape, dst_name, 3)  # ftype = Q4_1
 
     # data
     blob.tofile(fout)
 
+
 convert_non_q4("model.embed_tokens.weight", "tok_embeddings.weight")
 convert_non_q4("model.norm.weight", "norm.weight")
 convert_non_q4("lm_head.weight", "output.weight")
 
 for i in range(n_layer):
-    convert_q4(f"model.layers.{i}.self_attn.q_proj", f"layers.{i}.attention.wq.weight", permute=True)
-    convert_q4(f"model.layers.{i}.self_attn.k_proj", f"layers.{i}.attention.wk.weight", permute=True)
+    convert_q4(
+        f"model.layers.{i}.self_attn.q_proj",
+        f"layers.{i}.attention.wq.weight",
+        permute=True,
+    )
+    convert_q4(
+        f"model.layers.{i}.self_attn.k_proj",
+        f"layers.{i}.attention.wk.weight",
+        permute=True,
+    )
     convert_q4(f"model.layers.{i}.self_attn.v_proj", f"layers.{i}.attention.wv.weight")
     convert_q4(f"model.layers.{i}.self_attn.o_proj", f"layers.{i}.attention.wo.weight")
 
     convert_q4(f"model.layers.{i}.mlp.gate_proj", f"layers.{i}.feed_forward.w1.weight")
     convert_q4(f"model.layers.{i}.mlp.down_proj", f"layers.{i}.feed_forward.w2.weight")
-    convert_q4(f"model.layers.{i}.mlp.up_proj",   f"layers.{i}.feed_forward.w3.weight")
-
-    convert_non_q4(f"model.layers.{i}.input_layernorm.weight", f"layers.{i}.attention_norm.weight")
-    convert_non_q4(f"model.layers.{i}.post_attention_layernorm.weight", f"layers.{i}.ffn_norm.weight")
+    convert_q4(f"model.layers.{i}.mlp.up_proj", f"layers.{i}.feed_forward.w3.weight")
+
+    convert_non_q4(
+        f"model.layers.{i}.input_layernorm.weight", f"layers.{i}.attention_norm.weight"
+    )
+    convert_non_q4(
+        f"model.layers.{i}.post_attention_layernorm.weight",
+        f"layers.{i}.ffn_norm.weight",
+    )
 
 
 fout.close()
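
For reference, the header hunk above writes nine consecutive int32 fields in a fixed order: magic, version, n_vocab, n_embd, n_mult, n_head, n_layer, rot, ftype. The snippet below is a hypothetical verification helper, not part of the script, that reads those fields back from a converted file to sanity-check the output; the file name is a placeholder.

import struct


def read_ggmf_header(path):
    # Reads the nine int32 header fields in the same order the script
    # writes them with struct.pack("i", ...).
    names = [
        "magic", "version", "n_vocab", "n_embd", "n_mult",
        "n_head", "n_layer", "rot", "ftype",
    ]
    with open(path, "rb") as f:
        values = struct.unpack("9i", f.read(9 * 4))
    return dict(zip(names, values))


header = read_ggmf_header("out.bin")  # placeholder output path
assert header["magic"] == 0x67676D66  # the "ggmf" magic written above
print(header)
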

0 commit comments
