
Commit 58b9064

convert-hf : simplify BitNet pre-quantization
This still results in the exact same tensor weights and scales, but it reveals some weirdness in the current algorithm.
1 parent 678916f commit 58b9064

File tree

1 file changed: +12 -8 lines changed


convert-hf-to-gguf.py

Lines changed: 12 additions & 8 deletions
@@ -263,7 +263,10 @@ def write_tensors(self):
                     break
 
             for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
-                data: np.ndarray = data  # type hint
+                data: np.ndarray  # type hint
+                if len(data.shape) == 0:
+                    # otherwise single-value tensors get squeezed
+                    data = data.reshape((1,))
                 n_dims = len(data.shape)
                 data_dtype = data.dtype
                 data_qtype: gguf.GGMLQuantizationType | None = None
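Context (not part of the commit): the new reshape guards against .squeeze() in the generator above turning single-value tensors, likely including the BitNet scale tensors touched by this same commit, into 0-dimensional arrays that the rest of write_tensors does not expect. A minimal NumPy sketch of that behaviour:

import numpy as np

scale = np.array([0.42], dtype=np.float32)    # a single-value tensor, e.g. a quantization scale
squeezed = scale.squeeze()                    # squeeze() drops every size-1 axis -> 0-d array
print(squeezed.shape)                         # ()
restored = squeezed.reshape((1,)) if len(squeezed.shape) == 0 else squeezed
print(restored.shape)                         # (1,) -- a usable 1-D shape again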
@@ -334,7 +337,7 @@ def write_tensors(self):
                 shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
 
                 # reverse shape to make it similar to the internal ggml dimension order
-                shape_str = f"{{{', '.join(str(n) for n in reversed(shape)) or '1'}}}"
+                shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}"
 
                 # n_dims is implicit in the shape
                 logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
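A note on why the `or '1'` fallback can go (my reading, not stated in the commit): once 0-dimensional data is reshaped to (1,) above, the reversed-shape join can no longer produce an empty string:

shape = ()                                            # shape of a 0-d tensor (before this commit)
print(', '.join(str(n) for n in reversed(shape)))     # '' -- the old `or '1'` papered over this
shape = (1,)                                          # shape after the new reshape((1,))
print(', '.join(str(n) for n in reversed(shape)))     # '1' -- the fallback is no longer needed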
@@ -1442,12 +1445,13 @@ def set_gguf_parameters(self):
     def weight_quant(self, weight):
         dtype = weight.dtype
         weight = weight.float()
-        s = 1 / weight.abs().mean().clamp(min=1e-5)
-        weight = (weight * s).round().clamp(-1, 1) / s
-        scale = weight.abs().max().unsqueeze(0)
-        weight = torch.where(weight.abs().less(1e-6), 0, weight).type(dtype)
-        weight = torch.sign(weight).type(dtype)
-        return weight.type(dtype), scale.type(torch.float32)
+        scale = weight.abs().mean().clamp(min=1e-5)
+        iscale = 1 / scale
+        weight = (weight * iscale).round().clamp(-1, 1)
+        # TODO: use the scale directly instead of inverting it twice
+        # (this is also unnecessarily doubly inverted upstream)
+        # ref: https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/af89e318d78a70802061246bf037199d2fb97020/utils_quant.py#L10
+        return weight.type(dtype), (1 / iscale).type(torch.float32)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         new_name = self.map_tensor_name(name)
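To make the "exact same tensor weights and scales" claim from the commit message concrete, here is a small standalone check (my sketch, not part of the commit; the old_/new_ function names are mine, and it assumes at least one weight rounds to +-1 so the old max-based scale is well defined):

import torch

def old_weight_quant(weight):
    # pre-commit version of BitnetModel.weight_quant
    dtype = weight.dtype
    weight = weight.float()
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    weight = (weight * s).round().clamp(-1, 1) / s
    scale = weight.abs().max().unsqueeze(0)
    weight = torch.where(weight.abs().less(1e-6), 0, weight).type(dtype)
    weight = torch.sign(weight).type(dtype)
    return weight.type(dtype), scale.type(torch.float32)

def new_weight_quant(weight):
    # simplified version from this commit
    dtype = weight.dtype
    weight = weight.float()
    scale = weight.abs().mean().clamp(min=1e-5)
    iscale = 1 / scale
    weight = (weight * iscale).round().clamp(-1, 1)
    return weight.type(dtype), (1 / iscale).type(torch.float32)

w = torch.randn(64, 64)
old_w, old_s = old_weight_quant(w)
new_w, new_s = new_weight_quant(w)
assert torch.equal(old_w, new_w)            # same ternary {-1, 0, 1} weights
assert torch.equal(old_s.squeeze(), new_s)  # same scale; the old one just carried an extra size-1 dim

The "weirdness" the message alludes to is visible in both versions: the scale is inverted and then inverted back (1 / (1 / scale)), which the new code keeps only to stay bit-exact with the upstream reference linked in the TODO.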
