# Convert a GPTQ quantized LLaMA model to a ggml compatible file
# Based on: https://github.com/qwopqwop200/GPTQ-for-LLaMa
#
import os
import re
import sys
import json
import struct
import numpy as np
import torch
from sentencepiece import SentencePieceProcessor

if len(sys.argv) != 4:
    print("Usage: convert-gptq-to-ggml.py llamaXXb-4bit.pt tokenizer.model out.bin\n")
    sys.exit(1)

fname_model = sys.argv[1]
fname_tokenizer = sys.argv[2]

model = torch.load(fname_model, map_location="cpu")
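# The checkpoint is a single torch-saved state dict keyed by HF-style tensor names
# (model.embed_tokens.weight, model.layers.N.*, lm_head.weight).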

n_vocab, n_embd = model['model.embed_tokens.weight'].shape
n_layer = 1 + max(int(m.group(1)) for name in model
                  if (m := re.match(r'model\.layers\.([0-9]+)', name)))

# hardcoded:
n_mult = 256
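# n_head follows from n_layer: 32/40/60/80 layers correspond to the 7B/13B/30B/65B LLaMA variants.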
n_head = {32: 32, 40: 40, 60: 52, 80: 64}[n_layer]

tokenizer = SentencePieceProcessor(fname_tokenizer)

assert tokenizer.vocab_size() == n_vocab

fname_out = sys.argv[3]

fout = open(fname_out, "wb")

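# Write the ggml model header: magic followed by the hyperparameters the loader expects.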
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
fout.write(struct.pack("i", n_vocab))
fout.write(struct.pack("i", n_embd))
fout.write(struct.pack("i", n_mult))
fout.write(struct.pack("i", n_head))
fout.write(struct.pack("i", n_layer))
fout.write(struct.pack("i", n_embd // n_head)) # rot (obsolete)
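# ftype field: 4 presumably signals "mostly Q4_1 with some f16 tensors" to loaders of this era (assumption).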
fout.write(struct.pack("i", 4))


# This loop unchanged from convert-pth-to-ggml.py:
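# The vocab is embedded in the output file: one length-prefixed UTF-8 string per token id, in id order.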
for i in range(tokenizer.vocab_size()):
    if tokenizer.is_unknown(i):
        # "<unk>" token (translated as ??)
        text = " \u2047 ".encode("utf-8")
        fout.write(struct.pack("i", len(text)))
        fout.write(text)
    elif tokenizer.is_control(i):
        # "<s>"/"</s>" tokens
        fout.write(struct.pack("i", 0))
    elif tokenizer.is_byte(i):
        # "<U+XX>" tokens (which may be invalid UTF-8)
        piece = tokenizer.id_to_piece(i)
        if len(piece) != 6:
            print("Invalid token: " + piece)
            sys.exit(1)
        byte_value = int(piece[3:-1], 16)
        fout.write(struct.pack("i", 1))
        fout.write(struct.pack("B", byte_value))
    else:
        # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
        text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
        fout.write(struct.pack("i", len(text)))
        fout.write(text)

def write_header(shape, dst_name, ftype_cur):
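    # Per-tensor header: n_dims, name length, ftype, then the dims (innermost first)
    # and the name bytes; the caller writes the raw tensor data immediately afterwards.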
    sname = dst_name.encode('utf-8')
    fout.write(struct.pack("iii", len(shape), len(sname), ftype_cur))
    fout.write(struct.pack("i" * len(shape), *shape[::-1]))
    fout.write(sname)

def convert_non_q4(src_name, dst_name):
    v = model[src_name]
    shape = v.shape
    print("Processing non-Q4 variable: " + src_name + " with shape: ", shape, " and type: ", v.dtype)
    if len(shape) == 1:
        print("  Converting to float32")
        v = v.to(torch.float32)

    ftype_cur = {torch.float16: 1, torch.float32: 0}[v.dtype]

    # header
    write_header(shape, dst_name, ftype_cur)

    # data
    v.numpy().tofile(fout)

def convert_q4(src_name, dst_name, permute=False):
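    # GPTQ-for-LLaMa stores each quantized linear layer as four tensors:
    # qweight (eight 4-bit values packed per int32), scales, zeros and a bias.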
    zeros = model[f"{src_name}.zeros"].numpy()
    scales = model[f"{src_name}.scales"].numpy()
    bias = model[f"{src_name}.bias"].numpy()
    qweight = model[f"{src_name}.qweight"].numpy().T # transpose

    # Q4_1 does not support bias; good thing the bias is always all zeros.
    assert not np.any(bias)

    # Each int32 item is actually 8 int4 items packed together, and it's transposed.
    shape = (qweight.shape[0], qweight.shape[1] * 8)

    print("Processing Q4 variable: " + src_name + " with shape: ", shape)

    # The output format has the int4 weights in groups of 32 rather than 8.
    # It looks like this:
    # For each row:
    #   For each group of 32 columns:
    #     - scale (float32, 4 bytes)
    #     - addend (float32, 4 bytes)
    #     - weights (int4 * 32, 16 bytes)
    # Note that in the input, the scales and addends are shared between all
    # the columns in a row, so we end up wasting quite a bit of memory with
    # repeated scales and addends.
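    # (Q4_1 reconstructs a weight as scale * q + addend, so the GPTQ zero
    # points, which are subtracted on dequantization, become negated addends.)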

    addends = -zeros # flip sign

    # Since the output format is mixed between integers and floats, we have
    # to hackily view the floats as int32s just so numpy will let us
    # concatenate them.
    addends_view = addends.view(dtype=np.int32)
    scales_view = scales.view(dtype=np.int32)

    # Split into groups of 4 columns (i.e. 32 columns of quantized data):
    grouped = qweight.reshape([qweight.shape[0], qweight.shape[1] // 4, 4])

    # Repeat addends and scales:
    addends_rep = np.atleast_3d(addends_view).repeat(grouped.shape[1], axis=1)
    scales_rep = np.atleast_3d(scales_view).repeat(grouped.shape[1], axis=1)

    blob = np.concatenate([scales_rep, addends_rep, grouped], axis=2, casting='no')
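    # Each 32-column group is now 6 int32s: scale, addend, then 4 int32s
    # holding the 32 packed 4-bit weights (24 bytes per group).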

    if permute:
        # Permute some rows to undo the permutation done by convert_llama_weights_to_hf.py.
        # This can be done after the above conversion because it doesn't affect column order/layout.
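        # (The HF conversion interleaves each head's channels for its rotate_half RoPE;
        # llama.cpp expects the original layout, hence the undo.)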
        blob = (blob.reshape(n_head, 2, shape[0] // n_head // 2, *blob.shape[1:])
                .swapaxes(1, 2)
                .reshape(blob.shape))

    # header
    write_header(shape, dst_name, 3) # ftype = Q4_1

    # data
    blob.tofile(fout)

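# Embeddings, final norm and LM head are left unquantized by GPTQ, so they are
# written in their original float format; every tensor is renamed to what the
# ggml LLaMA loader expects.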
convert_non_q4("model.embed_tokens.weight", "tok_embeddings.weight")
convert_non_q4("model.norm.weight", "norm.weight")
convert_non_q4("lm_head.weight", "output.weight")

for i in range(n_layer):
    convert_q4(f"model.layers.{i}.self_attn.q_proj", f"layers.{i}.attention.wq.weight", permute=True)
    convert_q4(f"model.layers.{i}.self_attn.k_proj", f"layers.{i}.attention.wk.weight", permute=True)
    convert_q4(f"model.layers.{i}.self_attn.v_proj", f"layers.{i}.attention.wv.weight")
    convert_q4(f"model.layers.{i}.self_attn.o_proj", f"layers.{i}.attention.wo.weight")

    convert_q4(f"model.layers.{i}.mlp.gate_proj", f"layers.{i}.feed_forward.w1.weight")
    convert_q4(f"model.layers.{i}.mlp.down_proj", f"layers.{i}.feed_forward.w2.weight")
    convert_q4(f"model.layers.{i}.mlp.up_proj", f"layers.{i}.feed_forward.w3.weight")

    convert_non_q4(f"model.layers.{i}.input_layernorm.weight", f"layers.{i}.attention_norm.weight")
    convert_non_q4(f"model.layers.{i}.post_attention_layernorm.weight", f"layers.{i}.ffn_norm.weight")


fout.close()

print("Done. Output file: " + fname_out)
print("")
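# Example invocation (file names are illustrative):
#   python convert-gptq-to-ggml.py llama7b-4bit.pt tokenizer.model ggml-model-q4_1.bin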