Skip to content

Commit 6f5a6c3

Browse files
committed
merge conflict
1 parent e3165a7 commit 6f5a6c3

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

torchchat/utils/quantize.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
find_multiple,
4646
get_device_str,
4747
get_precision,
48+
set_precision,
4849
name_to_dtype,
4950
state_dict_device,
5051
use_et_backend,
@@ -115,6 +116,13 @@ def quantize_model(
115116
if not support_tensor_subclass:
116117
unwrap_tensor_subclass(model)
117118
continue
119+
120+
if quantizer in ["linear:a8wxdq", "embedding:wx"]:
121+
# These quantizers require float32 input weights. Note that after quantization,
122+
# the weights will no longer be float32, but lowbit integers
123+
if get_precision() != torch.float32:
124+
print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
125+
set_precision(torch.float32)
118126

119127
# We set global precision from quantize options if it is specified at cli.py:485
120128
# so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
@@ -909,8 +917,8 @@ def quantized_model(self) -> nn.Module:
909917
IntxWeightEmbeddingQuantizer,
910918
)
911919

912-
ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
913-
ao_quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
920+
quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
921+
quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
914922

915923
# Try loading custom op
916924
try:
@@ -929,6 +937,6 @@ def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None
929937
global torchao_experimental_load_error
930938
raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")
931939

932-
a8wxdq_load_error = e
940+
torchao_experimental_load_error = e
933941
quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
934942
quantizer_class_dict["embedding:wx"] = ErrorHandler

0 commit comments

Comments (0)