Skip to content

Commit a41b63c

Browse files
committed
fix: some tensors are in bf16
numpy does not support bf16, so we conservatively upcast to f32
1 parent 51b53d4 commit a41b63c

File tree

1 file changed: +11 additions, -5 deletions

models/convert.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def quantize_q5_1(x):
101101
def quantize_q8_0(x):
102102
assert x.shape[-1] % QK8_0 == 0 and x.shape[-1] > QK8_0
103103
x = x.reshape(-1, QK8_0)
104-
amax = np.max(np.abs(x), axis=-1, keepdims=True)
104+
amax = np.max(np.abs(x), axis=-1, keepdims=True)
105105
d = amax / ((1 << 7) - 1)
106106
qs = (x / d).round().clip(min=-128, max=127).astype(np.int8)
107107
d = d.astype(np.float16).view(np.int8)
@@ -178,7 +178,7 @@ def preprocess(state_dict):
178178
print("no alphas_cumprod in file, generate new one")
179179
alphas_cumprod = get_alpha_comprod()
180180
state_dict["alphas_cumprod"] = alphas_cumprod
181-
181+
182182
new_state_dict = {}
183183
for name, w in state_dict.items():
184184
# ignore unused tensors
@@ -251,7 +251,7 @@ def preprocess(state_dict):
251251
new_state_dict[new_name] = w
252252
print(f"preprocess {name} => {new_name}")
253253
continue
254-
254+
255255
# convert unet transformer linear to conv2d 1x1
256256
if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
257257
if len(w.shape) == 2:
@@ -421,7 +421,13 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
421421
continue
422422
if name in unused_tensors:
423423
continue
424-
data = state_dict[name].numpy()
424+
425+
data_tmp = state_dict[name]
426+
if data_tmp.dtype == torch.bfloat16:
427+
# numpy does not support bf16, so we conservatively upcast to f32
428+
data = data_tmp.float().numpy()
429+
else:
430+
data = data_tmp.numpy()
425431

426432
n_dims = len(data.shape)
427433
shape = data.shape
@@ -452,7 +458,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
452458
else:
453459
data = data.astype(np.float32)
454460
ttype = "f32"
455-
461+
456462
print("Processing tensor: {} with shape {}, {} -> {}".format(name, data.shape, old_type, ttype))
457463

458464
# header

0 commit comments

Comments (0)