Commit e33a233
Author: Jack Zhang

Fix scheduler bf16/fp16 mix error

1 parent: c2ced76

1 file changed (+7, -3)

quantization/quantize.py

@@ -15,7 +15,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import get_device_str, name_to_dtype, state_dict_device
+from build.utils import get_device_str, get_precision, name_to_dtype, state_dict_device
 
 from quantization.qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding
 # AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
@@ -51,8 +51,12 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
         ):
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         if quantizer in ao_quantizer_class_dict:
-            dtype = quantize_options.get("precision", {}).get("dtype", "float16")
-            precision = name_to_dtype(dtype, device)
+            # Use dtype precision specified in configuration, else fall back on global precision.
+            if "precision" in quantize_options:
+                dtype = quantize_options.get("precision", {}).get("dtype", name_to_dtype(get_precision()))
+                precision = name_to_dtype(dtype, device)
+            else:
+                precision = get_precision()
             try:
                 # Easier to ask forgiveness than permission
                 quant_handler = ao_quantizer_class_dict[quantizer](
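
For context, the change matters because the old code defaulted quantization precision to a hard-coded "float16" even when the model was running in bf16, which is the dtype mix the commit title refers to. The following is a minimal, self-contained sketch of the patched precision-resolution branch. resolve_precision is a hypothetical helper written only for illustration, and the get_precision/name_to_dtype stubs merely approximate the real helpers in build/utils (in particular, letting name_to_dtype pass through values that are already torch.dtype is an assumption made so the name_to_dtype(get_precision()) default round-trips).

import torch

_GLOBAL_PRECISION = torch.bfloat16  # hypothetical stand-in for the build-wide default dtype

def get_precision():
    # Stand-in for build.utils.get_precision: returns the global model dtype.
    return _GLOBAL_PRECISION

def name_to_dtype(name, device=None):
    # Stand-in for build.utils.name_to_dtype: maps a dtype name to a torch.dtype.
    # Passing dtypes through unchanged is an assumption for this sketch.
    if isinstance(name, torch.dtype):
        return name
    return {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[name]

def resolve_precision(quantize_options, device="cpu"):
    # Mirrors the patched branch: an explicit 'precision' entry wins,
    # otherwise the global precision is used instead of a hard-coded float16.
    if "precision" in quantize_options:
        dtype = quantize_options.get("precision", {}).get("dtype", name_to_dtype(get_precision()))
        return name_to_dtype(dtype, device)
    return get_precision()

print(resolve_precision({"linear:int8": {"groupsize": 8}}))    # torch.bfloat16
print(resolve_precision({"precision": {"dtype": "float16"}}))  # torch.float16

With no 'precision' entry in quantize_options, the global default now propagates, so a bf16 run stays bf16 end to end rather than silently mixing in fp16 quantized weights.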
