Commit 9fcc489

GPTQ: use optimum format by default (#2568)
* Use HuggingFace Optimum format for GPTQ checkpoint
* Fix issue in LLM examples

Co-authored-by: WeizhuoZhang-intel <[email protected]>
1 parent cfaa7a2 commit 9fcc489
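
With this change, checkpoints produced by the IPEX GPTQ flow use the HuggingFace Optimum key layout (qweight / scales / qzeros / g_idx) by default, so a loaded state dict can be handed to ipex.llm.optimize directly instead of being wrapped with a hand-written key-mapping dict. A minimal usage sketch based on the updated tests below; the model and the checkpoint path are placeholders, not part of this commit:

import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM

# Placeholder model and checkpoint path; any GPTQ checkpoint saved in the
# (now default) Optimum format is passed the same way.
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6b").eval()
low_precision_checkpoint = torch.load("gptq_checkpoint_g128.pt")

qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
    weight_dtype=torch.quint4x2,
    lowp_mode=ipex.quantization.WoqLowpMode.INT8,
)
model = ipex.llm.optimize(
    model,
    dtype=torch.float,
    quantization_config=qconfig,
    low_precision_checkpoint=low_precision_checkpoint,  # state dict alone is enough now
    deployment_mode=False,
    inplace=True,
)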

File tree (7 files changed: +178 -27 lines)

examples/cpu/inference/python/llm/run.py
examples/cpu/inference/python/llm/single_instance/run_quantization.py
intel_extension_for_pytorch/quantization/_GPTQ/_gptq_utils.py
intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py
intel_extension_for_pytorch/utils/weight_only_quantization.py
tests/cpu/test_ipex_optimize_transformers.py
tests/cpu/test_quantization_default_recipe.py

examples/cpu/inference/python/llm/run.py

Lines changed: 10 additions & 1 deletion
@@ -130,7 +130,14 @@ def main(args_in: Optional[List[str]] = None) -> None:
         "--gptq",
         action="store_true",
         help="Run GPTQ calibration to generate optimized INT4 weight for weight-only quantization."
-        "This is recommended for INT4 to minimize accuracy drop after quantization."
+        " This is recommended for INT4 to minimize accuracy drop after quantization."
+    )
+    parser.add_argument(
+        "--gptq-legacy-format",
+        action="store_true",
+        help="Indicate that the low-precision checkpoint is in the legacy format rather than the"
+        " HuggingFace Optimum format for backward compatibility. It must be used with"
+        " --low-precision-checkpoint. Otherwise, it has no effect."
     )
     parser.add_argument(
         "--group-size",
@@ -357,6 +364,8 @@ def main(args_in: Optional[List[str]] = None) -> None:
                     str(args.low_precision_checkpoint),
                 ]
             )
+            if args.gptq_legacy_format:
+                quant_cmd.extend(["--gptq-legacy-format"])
         else:
             # No need to set group size if args.gptq is true
             # Group size is read from the checkpoint

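For reference, a runnable sketch of how the launcher is meant to forward the new switch; the script path and the spelling of the checkpoint flag inside quant_cmd are assumptions based on the surrounding code, not shown verbatim in this hunk:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--low-precision-checkpoint", default="")
parser.add_argument("--gptq-legacy-format", action="store_true")
args = parser.parse_args(["--low-precision-checkpoint", "ckpt.pt", "--gptq-legacy-format"])

quant_cmd = ["python", "single_instance/run_quantization.py"]  # assumed script invocation
if args.low_precision_checkpoint != "":
    quant_cmd.extend(["--low-precision-checkpoint", str(args.low_precision_checkpoint)])
    # The legacy-format hint is only forwarded together with a checkpoint;
    # on its own, --gptq-legacy-format has no effect (see the help text above).
    if args.gptq_legacy_format:
        quant_cmd.extend(["--gptq-legacy-format"])
print(quant_cmd)
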
examples/cpu/inference/python/llm/single_instance/run_quantization.py

Lines changed: 12 additions & 9 deletions
@@ -149,6 +149,13 @@
     "PER_BATCH_IC_BLOCK(3): quantize per block of size 1 x IC_BLOCK. "
     "IC_BLOCK is determined by IC automatically.",
 )
+parser.add_argument(
+    "--gptq-legacy-format",
+    action="store_true",
+    help="Indicate that the low-precision checkpoint is in the legacy format rather than the"
+    " HuggingFace Optimum format for backward compatibility. It must be used with"
+    " --low-precision-checkpoint. Otherwise, it has no effect."
+)
 args = parser.parse_args()
 
 
@@ -605,15 +612,11 @@ def calib_func(prepared_model):
     )
     if args.low_precision_checkpoint != "":
         low_precision_checkpoint = torch.load(args.low_precision_checkpoint)
-        config_dict = {
-            "weight_key": "qweight",
-            "scale_key": "scales",
-            "zero_point_key": "qzeros",
-            "bias_key": "bias",
-            "g_idx_key": "g_idx"
-        }
-        state_dict_and_config = (low_precision_checkpoint, config_dict)
-        low_precision_checkpoint = state_dict_and_config
+        if args.gptq_legacy_format:
+            config_dict = (
+                ipex.utils.weight_only_quantization._legacy_lowp_checkpoint_config()
+            )
+            low_precision_checkpoint = (low_precision_checkpoint, config_dict)
     else:
         low_precision_checkpoint = None
     user_model = ipex.llm.optimize(

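Net effect for users of the example script: a checkpoint loaded with torch.load is now passed to ipex.llm.optimize as-is and interpreted with the Optimum key layout, and only legacy checkpoints need the extra config wrapper. A standalone sketch of that branch (the flag and path below stand in for the parsed command-line arguments):

import torch
import intel_extension_for_pytorch as ipex

gptq_legacy_format = False                    # stand-in for args.gptq_legacy_format
checkpoint_path = "gptq_checkpoint_g128.pt"   # stand-in for args.low_precision_checkpoint

low_precision_checkpoint = torch.load(checkpoint_path)
if gptq_legacy_format:
    # Legacy checkpoints (packed_weight / scale / packed_zp keys) still work,
    # but must carry the legacy key-mapping config.
    config_dict = ipex.utils.weight_only_quantization._legacy_lowp_checkpoint_config()
    low_precision_checkpoint = (low_precision_checkpoint, config_dict)
# Otherwise the plain state dict is passed on and read as HuggingFace Optimum format
# (qweight / scales / qzeros / g_idx), which is now the default.
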
intel_extension_for_pytorch/quantization/_GPTQ/_gptq_utils.py

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ def gptq_export(
                 compression_dim=compression_dim,
                 scale_dtype=scale_dtype,
                 device=torch.device("cpu"),
-                use_optimum_format=False,
+                use_optimum_format=True,
             )
             new_module.pack(int_weight, gptq_scale, gptq_zp, m.bias, gptq_perm)
             set_module(model, k, new_module)

intel_extension_for_pytorch/transformers/models/cpu/fusions/linear_fusion.py

Lines changed: 1 addition & 1 deletion
@@ -304,7 +304,7 @@ def __init__(self, module, tpp=False, woq=False):
             concat_weight = torch.concat(weights_list, 0)
             concat_scales = torch.concat(scales_list, 0)
             concat_zeros = torch.concat(zeros_list, 0)
-            use_bias = all(bias_list)
+            use_bias = all([b is not None for b in bias_list])
             concat_bias = torch.concat(bias_list, 0) if use_bias else None
             mod = nn.Linear(
                 concat_weight.shape[1], concat_weight.shape[0], use_bias

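The one-line fix above sidesteps calling bool() on bias tensors: all(bias_list) asks each tensor for its truth value, which raises for tensors with more than one element, while the new check only asks whether every fused linear actually has a bias. A small illustration with made-up biases:

import torch

bias_list = [torch.zeros(4), torch.ones(4)]  # made-up example biases

try:
    all(bias_list)  # calls bool() on each tensor
except RuntimeError as err:
    # RuntimeError: Boolean value of Tensor with more than one element is ambiguous
    print("all(bias_list) fails:", err)

use_bias = all([b is not None for b in bias_list])  # the fixed check
print(use_bias)  # True: every linear has a bias, so the biases can be concatenated
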
intel_extension_for_pytorch/utils/weight_only_quantization.py

Lines changed: 83 additions & 4 deletions
@@ -7,7 +7,18 @@
 # Weight shape is N by K if transposed is False otherwise K by N.
 # Bias is optional. If bias is not provided in the checkpoint, we read the original model.
 DEFAULT_LOWP_CHECKPOINT_CONFIG = {
-    "name": "default",
+    "name": "optimum",
+    "use_optimum_format": True,
+    "weight_key": "qweight",
+    "scale_key": "scales",
+    "zero_point_key": "qzeros",
+    "bias_key": "bias",
+    "g_idx_key": "g_idx",
+}
+
+LEGACY_LOWP_CHECKPOINT_CONFIG = {
+    "name": "legacy",
+    "use_optimum_format": False,
     "weight_key": "packed_weight",
     "scale_key": "scale",
     "zero_point_key": "packed_zp",
@@ -31,14 +42,75 @@ def _default_lowp_checkpoint_config():
     return DEFAULT_LOWP_CHECKPOINT_CONFIG
 
 
+def _legacy_lowp_checkpoint_config():
+    return LEGACY_LOWP_CHECKPOINT_CONFIG
+
+
 def _get_keys_from_config(checkpoint_config):
-    weight_key = checkpoint_config.get("weight_key", "weight")
-    scales_key = checkpoint_config.get("scale_key", "scale")
-    zeros_key = checkpoint_config.get("zero_point_key", "zero")
+    weight_key = checkpoint_config.get("weight_key", "qweight")
+    scales_key = checkpoint_config.get("scale_key", "scales")
+    zeros_key = checkpoint_config.get("zero_point_key", "qzeros")
     bias_key = checkpoint_config.get("bias_key", "bias")
     return weight_key, scales_key, zeros_key, bias_key
 
 
+def _convert_optimum_format_to_desired(qweight, scales, qzeros):
+    """
+    Optimum format:
+        qweight: (math.ceil(IC / comp_ratio), OC)
+        scales: (n_groups, OC)
+        qzeros: (n_groups, math.ceil(OC / comp_ratio))
+        qzeros are subtracted by 1 before packing
+
+    Desired format:
+        compression_dim = 1
+        qweight: (OC, math.ceil(IC / comp_ratio))
+        scales: (OC, n_groups)
+        qzeros: (OC, math.ceil(n_groups / comp_ratio))
+
+    Note:
+        IC = input channels or input features
+        OC = output channels or output features
+        n_groups = math.ceil(IC / group_size)
+        comp_ratio = compression data type bits // weight or zeros data type bits
+        E.g., compression dtype = int32, weight dtype = int4, comp_ratio = 32 / 4 = 8
+
+    """
+    if qweight is None:
+        return qweight, scales, qzeros
+    oc = qweight.shape[1]
+    assert oc == scales.shape[1]
+    n_groups = scales.shape[0]
+    qweight = qweight.t_().contiguous()
+    scales = scales.t_().contiguous()
+    if qzeros is None:
+        return qweight, scales, qzeros
+    zp_dtype = torch.int32
+    zp = torch.empty((n_groups, oc), dtype=zp_dtype)
+    # Steps to convert qzeros:
+    # (1) unpack qzeros to (n_groups, OC)
+    # (2) take transpose
+    # (3) plus one and handle overflow
+    zp_bits = 4  # int4
+    comp_dtype_bits = 32  # int32
+    comp_ratio = comp_dtype_bits // zp_bits
+    mask = torch.tensor(2**zp_bits - 1, dtype=zp_dtype)
+    for j in range(qzeros.shape[1]):
+        packed_data = qzeros[:, j]
+        for e in range(comp_ratio):
+            index = j * comp_ratio + e
+            if index >= zp.shape[1]:
+                continue
+            data = (packed_data >> (zp_bits * e)) & mask
+            zp[:, index] = data.type(zp_dtype)
+    zp = zp.t_().contiguous()
+    zp += 1
+    # it may overflow after adding one
+    zp = torch.where(zp > (2**zp_bits - 1), 0, zp)
+
+    return qweight, scales, zp
+
+
 def _get_linear_parameters(attr_name, state_dict, checkpoint_config):
     weight_key, scales_key, zeros_key, bias_key = _get_keys_from_config(
         checkpoint_config
@@ -52,6 +124,13 @@ def _get_linear_parameters(attr_name, state_dict, checkpoint_config):
     scales = state_dict.get(s_key, None)
     qzeros = state_dict.get(z_key, None)
     bias = state_dict.get(b_key, None)
+
+    use_optimum_format = checkpoint_config.get("use_optimum_format", True)
+    if use_optimum_format:
+        qweight, scales, qzeros = _convert_optimum_format_to_desired(
+            qweight, scales, qzeros
+        )
+
     group_size = -1
     if qweight is not None and scales is not None:
         assert scales.dim() == 2, "Unexpected scales tensor dimension"

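To make the layout conversion above concrete, here is a self-contained sketch for a toy layer. It re-implements the unpack / transpose / add-one steps described in the docstring of _convert_optimum_format_to_desired rather than calling the private helper itself, and all sizes and zero-point values are made up for the demo:

import torch

# Toy sizes: IC = 8, OC = 4, group_size = 8 -> n_groups = 1;
# int4 values packed into int32 -> comp_ratio = 32 // 4 = 8.
OC, n_groups, zp_bits, comp_ratio = 4, 1, 4, 8
mask = 2**zp_bits - 1

# Optimum-format tensors (values are arbitrary except the zero points).
qweight = torch.randint(-(2**31), 2**31 - 1, (1, OC), dtype=torch.int32)  # (ceil(IC/8), OC)
scales = torch.randn(n_groups, OC, dtype=torch.half)                      # (n_groups, OC)
raw_zp = [3, 0, 15, 7]  # the int4 zero points we want to store
packed = 0
for i, z in enumerate(raw_zp):
    packed |= ((z - 1) & mask) << (zp_bits * i)  # stored minus one, per the Optimum convention
qzeros = torch.tensor([[packed]], dtype=torch.int32)                      # (n_groups, ceil(OC/8))

# Conversion: transpose weight/scales, unpack zeros, transpose, add one, wrap overflow.
qweight_t = qweight.t().contiguous()  # (OC, ceil(IC/8))
scales_t = scales.t().contiguous()    # (OC, n_groups)
zp = torch.empty(n_groups, OC, dtype=torch.int32)
for e in range(comp_ratio):
    if e >= OC:  # only OC zero points live in this single packed int32
        break
    zp[:, e] = (qzeros[:, 0] >> (zp_bits * e)) & mask
zp = zp.t().contiguous() + 1
zp = torch.where(zp > mask, torch.zeros_like(zp), zp)  # adding one may overflow the int4 range

print(zp.flatten().tolist())  # -> [3, 0, 15, 7]: the original zero points recovered
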
tests/cpu/test_ipex_optimize_transformers.py

Lines changed: 70 additions & 2 deletions
@@ -373,8 +373,8 @@ def test_static_quant_flow(self):
         if not hasattr(ipex_m, "trace_graph"):
             AssertionError(False)
 
-    def test_weight_only_quant_gptq(self):
-        # import json
+    def test_weight_only_quant_gptq_legacy(self):
+        # Test the legacy format
         config = AutoConfig.from_pretrained(
             f"{curpath}/hf_configs/gptj", return_dict=False
         )
@@ -404,6 +404,74 @@
             torch.save(state_dict, checkpoint_file_name)
             state_dict = torch.load(checkpoint_file_name)
 
+            # test loading checkpoint and quant info
+            lowp_mode = ipex.quantization.WoqLowpMode.INT8
+            qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
+                lowp_mode=lowp_mode
+            )
+            config_dict = (
+                ipex.utils.weight_only_quantization._legacy_lowp_checkpoint_config()
+            )
+            ipex_m = ipex.llm.optimize(
+                ipex_m,
+                dtype=torch.float,
+                quantization_config=qconfig,
+                low_precision_checkpoint=(state_dict, config_dict),
+                deployment_mode=True,
+                inplace=True,
+            )
+            assert hasattr(ipex_m, "trace_graph")
+
+            # Ensure model can run without errors
+            with torch.no_grad():
+                example_inputs = _get_gptj_example_inputs()
+                # the optimized model is ipex_m.trace_graph
+                ipex_m.trace_graph(*example_inputs)
+
+    def test_weight_only_quant_gptq(self):
+        # Test the HuggingFace Optimum format
+        config = AutoConfig.from_pretrained(
+            f"{curpath}/hf_configs/gptj", return_dict=False
+        )
+        m = transformers.models.gptj.modeling_gptj.GPTJForCausalLM(config).eval()
+        ipex_m = copy.deepcopy(m)
+        with tempfile.TemporaryDirectory() as work_dir:
+            # Generate dummy checkpoint
+            checkpoint_file_name = work_dir + "/checkpoint.pt"
+            state_dict = ipex_m.state_dict()
+            linear_keys = []
+            for k, v in state_dict.items():
+                if any(
+                    k.endswith(suffix)
+                    for suffix in ["proj.weight", "fc_in.weight", "fc_out.weight"]
+                ):
+                    linear_keys.append(k[:-7])
+            group_size = 128
+            comp_ratio = 8
+            for k in linear_keys:
+                N = state_dict[k + ".weight"].shape[0]
+                K = state_dict[k + ".weight"].shape[1]
+                del state_dict[k + ".weight"]
+                n_groups = K // group_size
+                stored_weight_shape = (K // comp_ratio, N)
+                stored_scales_shape = (n_groups, N)
+                stored_zeros_shape = (n_groups, N // comp_ratio)
+                state_dict[k + ".qweight"] = torch.randint(
+                    -(2**31), 2**31 - 1, stored_weight_shape, dtype=torch.int32
+                )
+                state_dict[k + ".scales"] = torch.randn(
+                    stored_scales_shape, dtype=torch.half
+                )
+                state_dict[k + ".qzeros"] = torch.randint(
+                    -(2**31), 2**31 - 1, stored_zeros_shape, dtype=torch.int32
+                )
+                g_idx = torch.arange(n_groups).repeat(group_size)
+                g_idx[:] = g_idx[torch.randperm(K)]
+                state_dict[k + ".g_idx"] = g_idx
+
+            torch.save(state_dict, checkpoint_file_name)
+            state_dict = torch.load(checkpoint_file_name)
+
             # test loading checkpoint and quant info
             lowp_mode = ipex.quantization.WoqLowpMode.INT8
             qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(

tests/cpu/test_quantization_default_recipe.py

Lines changed: 1 addition & 9 deletions
@@ -729,14 +729,6 @@ def _get_gptj_example_inputs():
             self.assertTrue(torch.allclose(out0[0], out1[0], atol=1e-05))
 
             low_precision_checkpoint = torch.load(work_dir + "/gptq_checkpoint_g128.pt")
-            config_dict = {
-                "weight_key": "qweight",
-                "scale_key": "scales",
-                "zero_point_key": "qzeros",
-                "bias_key": "bias",
-                "g_idx_key": "g_idx",
-            }
-            state_dict_and_config = (low_precision_checkpoint, config_dict)
             qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
                 weight_dtype=torch.quint4x2,
                 lowp_mode=ipex.quantization.WoqLowpMode.INT8,
@@ -748,7 +740,7 @@ def _get_gptj_example_inputs():
                 dtype=torch.float,
                 quantization_config=qconfig,
                 inplace=True,
-                low_precision_checkpoint=state_dict_and_config,
+                low_precision_checkpoint=low_precision_checkpoint,
                 deployment_mode=False,
             )
             _IPEXAttentionCPU = (
