Commit 6f8d280

convert-hf : support bfloat16 conversion

1 parent 4426e29 commit 6f8d280

5 files changed: +345 −162 lines

convert-hf-to-gguf.py

Lines changed: 73 additions & 117 deletions
@@ -12,7 +12,7 @@
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
 
 import numpy as np
 import torch
@@ -65,8 +65,8 @@ class Model:
     model_arch: gguf.MODEL_ARCH
 
     def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
-        if self.__class__ == Model:
-            raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
+        if type(self) == Model:
+            raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
         self.ftype = ftype
         self.fname_out = fname_out
@@ -215,6 +215,23 @@ def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
         return False
 
     def write_tensors(self):
+        # same as ggml_compute_fp32_to_bf16 in ggml-impl.h
+        def np_fp32_to_bf16(n: np.ndarray):
+            # force nan to quiet
+            n = np.where((n & 0x7fffffff) > 0x7f800000, n | (64 << 16), n)
+            # flush subnormals to zero
+            n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
+            # round to nearest even
+            n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
+            return n
+
+        # Doing this row-wise is much, much faster than element-wise, hence the signature
+        v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
+        if self.lazy:
+            # TODO: find a way to implicitly wrap np.vectorize functions
+            # NOTE: the type is changed to reflect otypes passed to np.vectorize above
+            v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
+
         max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
 
         for name, data_torch in self.get_tensors():
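
As a side note, the bit trick above can be exercised on its own. Here is a minimal standalone sketch of the same fp32 to bf16 conversion (the sample values and the round-trip check are mine, not part of the commit; unsigned views are used to sidestep signed-overflow pitfalls in recent numpy):

import numpy as np

def fp32_bits_to_bf16(n: np.ndarray) -> np.ndarray:
    # n holds the raw bits of fp32 values, viewed as uint32
    n = np.where((n & 0x7fffffff) > 0x7f800000, n | (64 << 16), n)  # force NaNs quiet
    n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)          # flush subnormals to zero
    return (n + (0x7fff + ((n >> 16) & 1))) >> 16                   # round to nearest even

x = np.array([1.0, -2.5, 3.14159], dtype=np.float32)
bf16 = fp32_bits_to_bf16(x.view(np.uint32)).astype(np.uint16)

# bf16 is the high half of fp32, so widening the bits recovers the rounded value
print((bf16.astype(np.uint32) << 16).view(np.float32))  # [ 1.  -2.5  3.140625]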
@@ -239,10 +256,7 @@ def write_tensors(self):
                data: np.ndarray = data  # type hint
                n_dims = len(data.shape)
                data_dtype = data.dtype
-
-                # if f32 desired, convert any float16 to float32
-                if self.ftype == 0 and data_dtype == np.float16:
-                    data = data.astype(np.float32)
+                data_qtype: gguf.GGMLQuantizationType | None = None
 
                # when both are True, f32 should win
                extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
@@ -254,20 +268,33 @@ def write_tensors(self):
                # if f16 desired, convert any float32 2-dim weight tensors to float16
                extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2)
 
-                # when both extra_f32 and extra_f16 are False, convert to float32 by default
-                if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16):
-                    data = data.astype(np.float32)
+                if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
+                    if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
+                        if data_dtype != np.float16:
+                            data = data.astype(np.float16)
+                        data_qtype = gguf.GGMLQuantizationType.F16
+
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+                        if data_dtype != np.float32:
+                            data = data.astype(np.float32)
+                        data = v_fp32_to_bf16(data.view(np.int32))
+                        assert data.dtype == np.int16
+                        data_qtype = gguf.GGMLQuantizationType.BF16
 
-                if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32:
-                    data = data.astype(np.float16)
+                else:  # by default, convert to float32
+                    if data_dtype != np.float32:
+                        data = data.astype(np.float32)
+                    data_qtype = gguf.GGMLQuantizationType.F32
+
+                assert data_qtype is not None
 
                # reverse shape to make it similar to the internal ggml dimension order
                shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
 
                # n_dims is implicit in the shape
-                logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}")
+                logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
 
-                self.gguf_writer.add_tensor(new_name, data)
+                self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
 
     def write(self):
         self.write_tensors()
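
Boiled down, the new branch selects exactly one GGMLQuantizationType per tensor. A condensed sketch of the decision (pick_qtype is a hypothetical helper; the real code also converts data in place):

import gguf

def pick_qtype(ftype: gguf.LlamaFileType, extra_f32: bool, extra_f16: bool) -> gguf.GGMLQuantizationType:
    # f32 wins whenever both hints are set; only "extra f16" tensors follow the requested type
    if ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
        if ftype == gguf.LlamaFileType.MOSTLY_F16:
            return gguf.GGMLQuantizationType.F16
        if ftype == gguf.LlamaFileType.MOSTLY_BF16:
            return gguf.GGMLQuantizationType.BF16
    return gguf.GGMLQuantizationType.F32

# 1d tensors (not "extra f16") stay f32 even with --outtype bf16
assert pick_qtype(gguf.LlamaFileType.MOSTLY_BF16, False, False) is gguf.GGMLQuantizationType.F32
assert pick_qtype(gguf.LlamaFileType.MOSTLY_BF16, False, True) is gguf.GGMLQuantizationType.BF16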
@@ -2281,92 +2308,40 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
 
 
 # tree of lazy tensors
-class LazyTorchTensor:
-    _meta: Tensor
-    _data: Tensor | None
-    _args: tuple
-    _func: Callable[[tuple], Tensor] | None
-
-    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
-        self._meta = meta
-        self._data = data
-        self._args = args
-        self._func = func
-
-    @staticmethod
-    def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
-        # TODO: dict and set
-        if isinstance(o, (list, tuple)):
-            L = []
-            for item in o:
-                L.append(LazyTorchTensor._recurse_apply(item, fn))
-            if isinstance(o, tuple):
-                L = tuple(L)
-            return L
-        elif isinstance(o, LazyTorchTensor):
-            return fn(o)
-        else:
-            return o
-
-    def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
-        def wrapped_fn(*args, **kwargs):
-            if kwargs is None:
-                kwargs = {}
-            args = ((self,) if use_self else ()) + args
-
-            meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)
-
-            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
-        return wrapped_fn
-
-    def __getattr__(self, __name: str) -> Any:
-        meta_attr = getattr(self._meta, __name)
-        if callable(meta_attr):
-            return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
-        elif isinstance(meta_attr, torch.Tensor):
-            # for things like self.T
-            return self._wrap_fn(lambda s: getattr(s, __name))(self)
-        else:
-            return meta_attr
+class LazyTorchTensor(gguf.LazyBase):
+    _tensor_type = torch.Tensor
+    # to keep the type-checker happy
+    dtype: torch.dtype
+    shape: torch.Size
 
+    # only used when converting a torch.Tensor to a np.ndarray
     _dtype_map: dict[torch.dtype, type] = {
         torch.float16: np.float16,
         torch.float32: np.float32,
     }
 
-    def numpy(self) -> gguf.LazyTensor:
+    def numpy(self) -> gguf.LazyNumpyTensor:
         dtype = self._dtype_map[self.dtype]
-        return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
-
-    @overload
-    @staticmethod
-    def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
-
-    @overload
-    @staticmethod
-    def to_eager(t: tuple) -> tuple: ...
-
-    @staticmethod
-    def to_eager(t: Any) -> Any:
-        def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
-            # wake up the lazy tensor
-            if _t._data is None and _t._func is not None:
-                # recurse into its arguments
-                _t._args = LazyTorchTensor.to_eager(_t._args)
-                _t._data = _t._func(_t._args)
-            if _t._data is not None:
-                return _t._data
-            else:
-                raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
-
-        # recurse into lists and/or tuples, keeping their structure
-        return LazyTorchTensor._recurse_apply(t, simple_to_eager)
+        return gguf.LazyNumpyTensor(
+            meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
+            lazy=self._lazy,
+            args=(self,),
+            func=(lambda s: s[0].numpy())
+        )
 
-    @staticmethod
-    def from_eager(t: Tensor) -> Tensor:
-        if (t.__class__ == LazyTorchTensor):
+    @classmethod
+    def eager_to_meta(cls, t: Tensor) -> Tensor:
+        if t.is_meta:
             return t
-        return LazyTorchTensor(meta=t.detach().to("meta"), data=t)  # type: ignore
+        return t.detach().to("meta")
+
+    @classmethod
+    def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
+        m = m.detach()
+        if not m.is_meta:
+            m = m.to("meta")
+        m.dtype = dtype
+        return m
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
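
The helpers above lean on PyTorch's meta device: a meta tensor carries shape and dtype but no storage, which is what keeps the lazy bookkeeping cheap no matter how large the checkpoint is. A quick illustration, independent of the commit:

import torch

t = torch.ones(4, 4, dtype=torch.float32)
m = t.detach().to("meta")           # drops the data, keeps shape/dtype
print(m.is_meta, m.shape, m.dtype)  # True torch.Size([4, 4]) torch.float32
print((m @ m).shape)                # torch.Size([4, 4]) -- inferred without real data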
@@ -2377,28 +2352,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
 
         if func is torch.Tensor.numpy:
             return args[0].numpy()
-        if func is torch.equal:
-            eager_args = LazyTorchTensor.to_eager(args)
-            return func(*eager_args, **kwargs)
-
-        return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
-
-    # special methods bypass __getattr__, so they need to be added manually
-    # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
-    # NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
-    #       as self._meta is currently used), because then the following
-    #       operations would by default not be wrapped, and so not propagated
-    #       when the tensor is made eager.
-    # It's better to get non-silent errors for not-yet-supported operators.
-    # TODO: add more when needed to avoid clutter, or find a more concise way
-    def __neg__(self, *args):  # mamba
-        return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
-
-    def __add__(self, *args):  # gemma
-        return self._wrap_fn(torch.Tensor.__add__)(self, *args)
 
-    def __getitem__(self, *args):  # bloom falcon refact internlm2
-        return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
+        return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
 
 
 def parse_args() -> argparse.Namespace:
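
The single remaining fallback works because of PyTorch's __torch_function__ protocol: any class that defines the hook participates in dispatch for torch API calls, even without subclassing torch.Tensor, which is why the per-operator shims could be dropped. A toy demonstration of the mechanism (not the gguf.LazyBase implementation):

import torch

class Traced:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # every torch function touching a Traced argument lands here
        return f"deferred {func.__name__}"

print(torch.neg(Traced()))  # deferred neg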
@@ -2417,8 +2372,8 @@ def parse_args() -> argparse.Namespace:
         help="path to write to; default: based on input",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16"], default="f16",
-        help="output format - use f32 for float32, f16 for float16",
+        "--outtype", type=str, choices=["f32", "f16", "bf16"], default="f16",
+        help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16",
     )
     parser.add_argument(
         "--bigendian", action="store_true",
@@ -2472,9 +2427,10 @@ def main() -> None:
         logger.error(f'Error: {args.model} is not a directory')
         sys.exit(1)
 
-    ftype_map = {
-        "f32": gguf.GGMLQuantizationType.F32,
-        "f16": gguf.GGMLQuantizationType.F16,
+    ftype_map: dict[str, gguf.LlamaFileType] = {
+        "f32": gguf.LlamaFileType.ALL_F32,
+        "f16": gguf.LlamaFileType.MOSTLY_F16,
+        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
     }
 
     if args.outfile is not None:

gguf-py/gguf/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .constants import *
+from .lazy import *
 from .gguf_reader import *
 from .gguf_writer import *
 from .tensor_mapping import *

gguf-py/gguf/constants.py

Lines changed: 41 additions & 0 deletions
@@ -820,6 +820,47 @@ class GGMLQuantizationType(IntEnum):
     BF16 = 30
 
 
+# TODO: add GGMLFileType from ggml_ftype in ggml.h
+
+
+# from llama_ftype in llama.h
+# ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE.
+class LlamaFileType(IntEnum):
+    ALL_F32              = 0
+    MOSTLY_F16           = 1   # except 1d tensors
+    MOSTLY_Q4_0          = 2   # except 1d tensors
+    MOSTLY_Q4_1          = 3   # except 1d tensors
+    MOSTLY_Q4_1_SOME_F16 = 4   # tok_embeddings.weight and output.weight are F16
+    # MOSTLY_Q4_2        = 5   # support has been removed
+    # MOSTLY_Q4_3        = 6   # support has been removed
+    MOSTLY_Q8_0          = 7   # except 1d tensors
+    MOSTLY_Q5_0          = 8   # except 1d tensors
+    MOSTLY_Q5_1          = 9   # except 1d tensors
+    MOSTLY_Q2_K          = 10  # except 1d tensors
+    MOSTLY_Q3_K_S        = 11  # except 1d tensors
+    MOSTLY_Q3_K_M        = 12  # except 1d tensors
+    MOSTLY_Q3_K_L        = 13  # except 1d tensors
+    MOSTLY_Q4_K_S        = 14  # except 1d tensors
+    MOSTLY_Q4_K_M        = 15  # except 1d tensors
+    MOSTLY_Q5_K_S        = 16  # except 1d tensors
+    MOSTLY_Q5_K_M        = 17  # except 1d tensors
+    MOSTLY_Q6_K          = 18  # except 1d tensors
+    MOSTLY_IQ2_XXS       = 19  # except 1d tensors
+    MOSTLY_IQ2_XS        = 20  # except 1d tensors
+    MOSTLY_Q2_K_S        = 21  # except 1d tensors
+    MOSTLY_IQ3_XS        = 22  # except 1d tensors
+    MOSTLY_IQ3_XXS       = 23  # except 1d tensors
+    MOSTLY_IQ1_S         = 24  # except 1d tensors
+    MOSTLY_IQ4_NL        = 25  # except 1d tensors
+    MOSTLY_IQ3_S         = 26  # except 1d tensors
+    MOSTLY_IQ3_M         = 27  # except 1d tensors
+    MOSTLY_IQ2_S         = 28  # except 1d tensors
+    MOSTLY_IQ2_M         = 29  # except 1d tensors
+    MOSTLY_IQ4_XS        = 30  # except 1d tensors
+    MOSTLY_IQ1_M         = 31  # except 1d tensors
+    MOSTLY_BF16          = 32  # except 1d tensors
+
+
 class GGUFEndian(IntEnum):
     LITTLE = 0
     BIG = 1
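
Because these values mirror llama_ftype in llama.h, the enum can round-trip the raw file-type integer that converters store in GGUF metadata. For instance:

import gguf

ft = gguf.LlamaFileType(32)
print(ft.name, ft.value)  # MOSTLY_BF16 32
assert ft is gguf.LlamaFileType.MOSTLY_BF16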
