torchchat/utils/quantize.py (11 additions, 3 deletions)
@@ -45,6 +45,7 @@
     find_multiple,
     get_device_str,
     get_precision,
+    set_precision,
     name_to_dtype,
     state_dict_device,
     use_et_backend,
@@ -115,6 +116,13 @@ def quantize_model(
                 if not support_tensor_subclass:
                     unwrap_tensor_subclass(model)
                 continue
+
+            if quantizer in ["linear:a8wxdq", "embedding:wx"]:
+                # These quantizers require float32 input weights. Note that after quantization,
+                # the weights will no longer be float32, but lowbit integers.
+                if get_precision() != torch.float32:
+                    print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
+                    set_precision(torch.float32)
 
             # We set global precision from quantize options if it is specified at cli.py:485
             # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
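For context, the change above amounts to a small precision guard: before applying a quantizer that only accepts float32 weights, check the global precision and force it to float32. Below is a minimal, self-contained sketch of that pattern. The get_precision/set_precision functions here are simplified stand-ins for torchchat's helpers (imported in the first hunk), and ensure_float32_for is a hypothetical wrapper used only for illustration, not part of the torchchat API.

import torch

# Hypothetical module-level default; torchchat tracks precision globally.
_precision = torch.float16

def get_precision() -> torch.dtype:
    return _precision

def set_precision(dtype: torch.dtype) -> None:
    global _precision
    _precision = dtype

def ensure_float32_for(quantizer: str) -> None:
    # Quantizers that consume float32 input weights (per the diff above).
    # After quantization the stored weights are lowbit integers, so the
    # float32 requirement applies only to the inputs the quantizer reads.
    if quantizer in ["linear:a8wxdq", "embedding:wx"]:
        if get_precision() != torch.float32:
            print(
                f"Quantizer {quantizer} requires float32 inputs, but received "
                f"{get_precision()}. Changing dtype to float32."
            )
            set_precision(torch.float32)

ensure_float32_for("linear:a8wxdq")
assert get_precision() == torch.float32

Because get_precision() is treated as the authoritative dtype throughout torchchat (see the trailing comment in the hunk), mutating it once here keeps every later precision lookup consistent with what the quantizer requires.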