
Commit 143b4b6

wip: 1.625 bpw ternary packing scheme

1 parent 89c7e4c commit 143b4b6

11 files changed: +495 −49 lines

convert-hf-to-gguf.py

Lines changed: 39 additions & 34 deletions
```diff
@@ -294,12 +294,27 @@ def write_tensors(self):
                 ))
 
                 if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
-                    if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+                    if self.ftype == gguf.LlamaFileType.MOSTLY_Q1_3 and not any(
+                        self.match_model_tensor_name(new_name, key, None)
+                        for key in [
+                            gguf.MODEL_TENSOR.TOKEN_EMBD,
+                            gguf.MODEL_TENSOR.OUTPUT,
+                        ]
+                    ):
+                        data = gguf.quantize_q1_3(data)
+                        assert data.dtype == np.uint8
+                        data_qtype = gguf.GGMLQuantizationType.Q1_3
+
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
                         data = gguf.quantize_bf16(data)
                         assert data.dtype == np.int16
                         data_qtype = gguf.GGMLQuantizationType.BF16
 
-                    elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
+                    elif (
+                        self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0
+                        or self.ftype == gguf.LlamaFileType.MOSTLY_Q1_3
+                        and gguf.can_quantize_to_q8_0(data)
+                    ):
                         data = gguf.quantize_q8_0(data)
                         assert data.dtype == np.uint8
                         data_qtype = gguf.GGMLQuantizationType.Q8_0
```
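Note how the reworked `elif` reads: Python's `and` binds tighter than `or`, so the new condition parses as `MOSTLY_Q8_0 or (MOSTLY_Q1_3 and can_quantize_to_q8_0(data))`. Together with the exclusion list in the first branch, this is what routes the token embedding and output tensors of a Q1_3 conversion to Q8_0 when their shape allows it (and, as written, a plain Q8_0 conversion no longer consults `can_quantize_to_q8_0` here). A quick illustration of the precedence:

```python
# `A or B and C` groups as `A or (B and C)` in Python, not `(A or B) and C`.
A, B, C = True, False, False
assert (A or B and C) == (A or (B and C))
assert (A or B and C) != ((A or B) and C)  # True vs. False for these values
```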
```diff
@@ -1401,6 +1416,12 @@ def write_tensors(self):
 class BitnetModel(Model):
     model_arch = gguf.MODEL_ARCH.BITNET
 
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None):
+        if ftype == gguf.LlamaFileType.GUESSED:
+            ftype = gguf.LlamaFileType.MOSTLY_Q1_3
+
+        super().__init__(dir_model, ftype, fname_out, is_big_endian, use_temp_file, eager, model_name)
+
     def set_vocab(self):
         self._set_vocab_sentencepiece()
```

```diff
@@ -1420,40 +1441,24 @@ def weight_quant(self, weight):
         return weight.type(dtype), scale.type(torch.float32)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # transform weight into 1/0/-1 (in fp32)
-        if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
-                          "down_proj.weight", "up_proj.weight", "gate_proj.weight",
-                          "o_proj.weight")):
-            weight_torch, scale_torch = self.weight_quant(data_torch)
-
-        tensors: list[tuple[str, Tensor]] = []
 
-        if name.endswith("q_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid, suffix=".scale"), scale_torch))
-        elif name.endswith("k_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid, suffix=".scale"), scale_torch))
-        elif name.endswith("v_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid, suffix=".scale"), scale_torch))
-        elif name.endswith("o_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid, suffix=".scale"), scale_torch))
-        elif name.endswith("up_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid, suffix=".scale"), scale_torch))
-        elif name.endswith("down_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid, suffix=".scale"), scale_torch))
-        elif name.endswith("gate_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid, suffix=".scale"), scale_torch))
-
-        if len(tensors) == 0:
-            tensors.append((self.map_tensor_name(name), data_torch))
+        new_name = self.map_tensor_name(name)
 
-        return tensors
+        if any(self.match_model_tensor_name(new_name, key, bid) for key in [
+            gguf.MODEL_TENSOR.ATTN_Q,
+            gguf.MODEL_TENSOR.ATTN_K,
+            gguf.MODEL_TENSOR.ATTN_V,
+            gguf.MODEL_TENSOR.ATTN_OUT,
+            gguf.MODEL_TENSOR.FFN_UP,
+            gguf.MODEL_TENSOR.FFN_DOWN,
+            gguf.MODEL_TENSOR.FFN_GATE,
+        ]):
+            # transform weight into 1/0/-1 (in fp32)
+            weight_torch, scale_torch = self.weight_quant(data_torch)
+            yield (new_name, weight_torch)
+            yield (new_name.removesuffix(".weight") + ".scale", scale_torch)
+        else:
+            yield (new_name, data_torch)
 
 
 @Model.register("GrokForCausalLM")
```
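For context, `weight_quant` (its body sits just above this hunk) is what produces the paired `.weight`/`.scale` tensors the new `modify_tensors` yields: the `.weight` tensor holds pure {-1, 0, +1} values and the `.scale` tensor the single float needed to rescale them. A minimal sketch of a BitNet b1.58-style ternarization under the usual abs-mean recipe (an assumption; the commit's actual body is not shown here):

```python
import torch

def weight_quant(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Hedged sketch, not the commit's code: ternarize with an abs-mean scale.
    dtype = weight.dtype
    w = weight.float()
    inv_scale = 1.0 / w.abs().mean().clamp(min=1e-5)  # guard against all-zero tensors
    ternary = (w * inv_scale).round().clamp(-1, 1)    # every element in {-1, 0, +1}
    scale = (1.0 / inv_scale).reshape(1)              # per-tensor scale to undo it
    return ternary.type(dtype), scale.type(torch.float32)
```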

examples/quantize/quantize.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -26,6 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ2_M",  LLAMA_FTYPE_MOSTLY_IQ2_M,  " 2.7  bpw quantization",            },
     { "IQ1_S",  LLAMA_FTYPE_MOSTLY_IQ1_S,  " 1.56 bpw quantization",            },
     { "IQ1_M",  LLAMA_FTYPE_MOSTLY_IQ1_M,  " 1.75 bpw quantization",            },
+    { "Q1_3",   LLAMA_FTYPE_MOSTLY_Q1_3,   " 1.63 bpw for BitNet 1.58b"         },
     { "Q2_2",   LLAMA_FTYPE_MOSTLY_Q2_2,   " 2    bpw quantization",            },
     { "Q2_K",   LLAMA_FTYPE_MOSTLY_Q2_K,   " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
```

ggml-common.h

Lines changed: 44 additions & 0 deletions
```diff
@@ -137,6 +137,14 @@ typedef sycl::half2 ggml_half2;
 
 #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
 
+// 1.625 bpw for BitNet 1.58b models
+#define QK1_3 64
+typedef struct {
+    uint8_t q[(QK1_3 - 4*QK1_3/64)/5]; // 5 elements per byte (3^5 = 243 < 256)
+    uint8_t qs[QK1_3/64]; // 4 elements per byte
+} block_q1_3;
+static_assert(sizeof(block_q1_3) == (QK1_3 - 4*QK1_3/64)/5 + QK1_3/64, "wrong q1_3 block size/padding");
+
 #define QK2_2 32
 typedef struct {
     uint8_t qs[QK2_2 / 4]; // nibbles / quants
```
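The struct arithmetic is where the 1.625 bpw figure comes from: with QK1_3 = 64, `q` is (64 − 4)/5 = 12 bytes holding 5 ternary values each (60 elements) and `qs` is 64/64 = 1 byte holding the remaining 4, so a 64-element block takes 13 bytes, and 13 × 8 / 64 = 1.625 bits per weight (the "1.63 bpw" shown by quantize, rounded). A sketch of the 5-per-byte packing this implies, assuming the scaled base-3 convention suggested by the `q1_3_grid` table further down (the base-3 value is stretched from [0, 242] to [0, 255] so digits can later be peeled off with multiplies and shifts):

```python
def pack5(trits: list[int]) -> int:
    """Pack five values from {-1, 0, +1} into one byte (sketch, not the commit's code)."""
    assert len(trits) == 5 and all(t in (-1, 0, 1) for t in trits)
    n = sum((t + 1) * 3**k for k, t in enumerate(trits))  # base 3, LSB first: 0..242
    return -(-n * 256 // 243)  # ceil(n * 256 / 243): stretch to 0..255

def unpack5(q: int) -> list[int]:
    """Invert pack5: (q * 243) >> 8 recovers n exactly, then peel off base-3 digits."""
    n = (q * 243) >> 8
    return [(n // 3**k) % 3 - 1 for k in range(5)]

# Round-trip check over every representable quintuple of trits:
for n in range(243):
    trits = [(n // 3**k) % 3 - 1 for k in range(5)]
    assert unpack5(pack5(trits)) == trits
```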
```diff
@@ -339,6 +347,7 @@ typedef struct {
 } block_iq3_s;
 static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
 
+// 1.5625 bpw
 typedef struct {
     ggml_half d;
     uint8_t qs[QK_K/8];
```
```diff
@@ -1095,6 +1104,41 @@ GGML_TABLE_BEGIN(uint32_t, q22_grid, 256)
     0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff,
 GGML_TABLE_END()
 
+GGML_TABLE_BEGIN(uint32_t, q1_3_grid, 256)
+    0xffffffff, 0xffffffff, 0xffffff00, 0xffffff01, 0xffff00ff, 0xffff0000, 0xffff0001, 0xffff01ff,
+    0xffff0100, 0xffff0101, 0xff00ffff, 0xff00ff00, 0xff00ff01, 0xff0000ff, 0xff000000, 0xff000001,
+    0xff0001ff, 0xff000100, 0xff000101, 0xff01ffff, 0xff01ffff, 0xff01ff00, 0xff01ff01, 0xff0100ff,
+    0xff010000, 0xff010001, 0xff0101ff, 0xff010100, 0xff010101, 0x00ffffff, 0x00ffff00, 0x00ffff01,
+    0x00ff00ff, 0x00ff0000, 0x00ff0001, 0x00ff01ff, 0x00ff0100, 0x00ff0101, 0x0000ffff, 0x0000ff00,
+    0x0000ff00, 0x0000ff01, 0x000000ff, 0x00000000, 0x00000001, 0x000001ff, 0x00000100, 0x00000101,
+    0x0001ffff, 0x0001ff00, 0x0001ff01, 0x000100ff, 0x00010000, 0x00010001, 0x000101ff, 0x00010100,
+    0x00010101, 0x01ffffff, 0x01ffff00, 0x01ffff01, 0x01ffff01, 0x01ff00ff, 0x01ff0000, 0x01ff0001,
+    0x01ff01ff, 0x01ff0100, 0x01ff0101, 0x0100ffff, 0x0100ff00, 0x0100ff01, 0x010000ff, 0x01000000,
+    0x01000001, 0x010001ff, 0x01000100, 0x01000101, 0x0101ffff, 0x0101ff00, 0x0101ff01, 0x0101ff01,
+    0x010100ff, 0x01010000, 0x01010001, 0x010101ff, 0x01010100, 0x01010101, 0xffffffff, 0xffffff00,
+    0xffffff01, 0xffff00ff, 0xffff0000, 0xffff0001, 0xffff01ff, 0xffff0100, 0xffff0101, 0xff00ffff,
+    0xff00ff00, 0xff00ff01, 0xff0000ff, 0xff0000ff, 0xff000000, 0xff000001, 0xff0001ff, 0xff000100,
+    0xff000101, 0xff01ffff, 0xff01ff00, 0xff01ff01, 0xff0100ff, 0xff010000, 0xff010001, 0xff0101ff,
+    0xff010100, 0xff010101, 0x00ffffff, 0x00ffff00, 0x00ffff01, 0x00ff00ff, 0x00ff0000, 0x00ff0000,
+    0x00ff0001, 0x00ff01ff, 0x00ff0100, 0x00ff0101, 0x0000ffff, 0x0000ff00, 0x0000ff01, 0x000000ff,
+    0x00000000, 0x00000001, 0x000001ff, 0x00000100, 0x00000101, 0x0001ffff, 0x0001ff00, 0x0001ff01,
+    0x000100ff, 0x00010000, 0x00010000, 0x00010001, 0x000101ff, 0x00010100, 0x00010101, 0x01ffffff,
+    0x01ffff00, 0x01ffff01, 0x01ff00ff, 0x01ff0000, 0x01ff0001, 0x01ff01ff, 0x01ff0100, 0x01ff0101,
+    0x0100ffff, 0x0100ff00, 0x0100ff01, 0x010000ff, 0x01000000, 0x01000001, 0x01000001, 0x010001ff,
+    0x01000100, 0x01000101, 0x0101ffff, 0x0101ff00, 0x0101ff01, 0x010100ff, 0x01010000, 0x01010001,
+    0x010101ff, 0x01010100, 0x01010101, 0xffffffff, 0xffffff00, 0xffffff01, 0xffff00ff, 0xffff0000,
+    0xffff0001, 0xffff01ff, 0xffff01ff, 0xffff0100, 0xffff0101, 0xff00ffff, 0xff00ff00, 0xff00ff01,
+    0xff0000ff, 0xff000000, 0xff000001, 0xff0001ff, 0xff000100, 0xff000101, 0xff01ffff, 0xff01ff00,
+    0xff01ff01, 0xff0100ff, 0xff010000, 0xff010001, 0xff0101ff, 0xff0101ff, 0xff010100, 0xff010101,
+    0x00ffffff, 0x00ffff00, 0x00ffff01, 0x00ff00ff, 0x00ff0000, 0x00ff0001, 0x00ff01ff, 0x00ff0100,
+    0x00ff0101, 0x0000ffff, 0x0000ff00, 0x0000ff01, 0x000000ff, 0x00000000, 0x00000001, 0x000001ff,
+    0x00000100, 0x00000100, 0x00000101, 0x0001ffff, 0x0001ff00, 0x0001ff01, 0x000100ff, 0x00010000,
+    0x00010001, 0x000101ff, 0x00010100, 0x00010101, 0x01ffffff, 0x01ffff00, 0x01ffff01, 0x01ff00ff,
+    0x01ff0000, 0x01ff0001, 0x01ff01ff, 0x01ff0100, 0x01ff0101, 0x01ff0101, 0x0100ffff, 0x0100ff00,
+    0x0100ff01, 0x010000ff, 0x01000000, 0x01000001, 0x010001ff, 0x01000100, 0x01000101, 0x0101ffff,
+    0x0101ff00, 0x0101ff01, 0x010100ff, 0x01010000, 0x01010001, 0x010101ff, 0x01010100, 0x01010101,
+GGML_TABLE_END()
+
 #define NGRID_IQ1S 2048
 #define IQ1S_DELTA 0.125f
 #define IQ1M_DELTA 0.125f
```
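The 256-entry table is consistent with decoding four trits at once from one packed byte: spot-checking rows suggests `q1_3_grid[b]` holds the four low base-3 digits of `(b * 243) >> 8`, one signed byte per trit (0xff, 0x00, 0x01 for −1, 0, +1) with the least-significant trit in the low byte. The occasional adjacent duplicates (256 byte values collapsing onto 243 codes) fall out of the same scaling. A hypothetical generator under that reading:

```python
def q1_3_grid_entry(b: int) -> int:
    # Assumed reconstruction, matched against the rows above on spot checks.
    n = (b * 243) >> 8                    # undo the 256/243 stretch of the packed byte
    word = 0
    for k in range(4):
        trit = (n // 3**k) % 3 - 1        # k-th base-3 digit mapped to {-1, 0, +1}
        word |= (trit & 0xFF) << (8 * k)  # -1 -> 0xff, 0 -> 0x00, +1 -> 0x01
    return word

row0 = ", ".join(f"0x{q1_3_grid_entry(b):08x}" for b in range(8))
# -> "0xffffffff, 0xffffffff, 0xffffff00, 0xffffff01, ..." (matches the first row)
```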
