
Commit e4a51a8

mikekgfb authored and malfet committed
Unified quantizer (#624)
* remove debug print statements and run linter
* use unified quantizer architecture
* use unified quantizer architecture
* use unified quantizer architecture
* typos & lint
* typos & lint
1 parent 24a7f61 commit e4a51a8
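The unified architecture drops the old create_quantized_state_dict / convert_for_runtime / load_state_dict sequence in favor of a single recursive quantize(module) pass per handler. A minimal sketch of that pattern, for orientation only (the class below is illustrative, not part of the repo; the real handlers replace modules with the quantized ops from qops.py):

import torch.nn as nn


class UnifiedQuantHandlerSketch:
    """Illustrative stand-in for the unified quantizer pattern in this commit."""

    def __init__(self, model: nn.Module):
        self.model_ = model

    def quantize(self, module: nn.Module) -> nn.Module:
        # Walk the module tree once, swapping supported layers in place
        # and recursing into every other child.
        for name, child in module.named_children():
            if isinstance(child, nn.Embedding):
                # Placeholder replacement; the real handler builds a
                # QuantizedEmbedding from the quantized weight and scales.
                setattr(module, name, nn.Identity())
            else:
                self.quantize(child)
        return module

    def quantized_model(self) -> nn.Module:
        # One call now does everything; no separate state-dict rewrite step.
        return self.quantize(self.model_)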

File tree: 2 files changed (+71, −85 lines)

qops.py

Lines changed: 38 additions & 35 deletions
@@ -12,7 +12,8 @@
     state_dict_device,
     use_et_backend,
 )
-from torch.nn.parameter import Parameter
+
+# from torch.nn.parameter import Parameter
 
 
 def linear_int8_aoti(input, weight, scales):
@@ -147,7 +148,9 @@ def __init__(
         ), "must specify both weights and scales, or neither"
         if weight is None:
             weight = torch.empty(
-                (out_features, in_features), dtype=torch.int8, device=device
+                (out_features, in_features),
+                dtype=torch.int8,
+                device=device,
             )
         if groupsize is None or (groupsize == 0):
             scales = torch.empty(out_features, dtype=dtype, device=device)
@@ -180,56 +183,56 @@ def __init__(
         *,
         bitwidth: int,
         groupsize: Optional[int] = None,
+        weight: Optional[torch.Tensor] = None,
+        scales: Optional[torch.Tensor] = None,
     ) -> None:
         super().__init__()
         if dtype is None:
-            dtype = torch.half
-
+            dtype = get_precision()
         if groupsize is None or groupsize == 0:
             groupsize = embedding_dim
         self.groupsize = groupsize
         self.dtype = dtype
         self.bitwidth = bitwidth
 
-        if use_et_backend():
-            self.forward = self.et_forward
-        else:
-            self.forward = self.aoti_forward
+        assert (weight is None) == bool(
+            scales is None
+        ), "must specify both weights and scales, or neither"
 
-        if bitwidth == 8:
-            self.register_buffer(
-                "weight",
-                torch.empty(
-                    (num_embeddings, embedding_dim), dtype=torch.int8, device=device
-                ),
-            )
-        elif bitwidth == 4:  # packed
-            self.register_buffer(
-                "weight",
-                torch.empty(
-                    (num_embeddings, embedding_dim // 2),
-                    dtype=torch.uint8,
-                    device=device,
-                ),
-            )
-        else:
+        if bitwidth not in [4, 8]:
             raise RuntimeError(
                 f"QUantized embedding does not support bitwidth={bitwidth}"
             )
 
-        groups_per_row = (embedding_dim + groupsize - 1) // groupsize
-        if groups_per_row > 1:
-            self.register_buffer(
-                "scales",
-                torch.ones(
-                    (num_embeddings, groups_per_row), dtype=torch.float16, device=device
+        if weight is None:
+            groups_per_row = (embedding_dim + groupsize - 1) // groupsize
+            weight = torch.empty(
+                (
+                    num_embeddings,
+                    (embedding_dim * bitwidth) // 8,
                 ),
+                dtype=torch.int8,
+                device=device,
             )
+            scales = torch.empty(
+                (num_embeddings, groups_per_row),
+                dtype=dtype,
+                device=device,
+            ).squeeze(dim=-1)
+
+        self.register_buffer(
+            "weight",
+            weight,
+        )
+        self.register_buffer(
+            "scales",
+            scales,
+        )
+
+        if use_et_backend():
+            self.forward = self.et_forward
         else:
-            self.register_buffer(
-                "scales",
-                torch.ones((num_embeddings,), dtype=torch.float16, device=device),
-            )
+            self.forward = self.aoti_forward
 
     @torch.no_grad()
     def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
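With this change QuantizedEmbedding either allocates empty int8/scale buffers or accepts pre-quantized tensors directly. A hedged construction sketch, assuming the keyword names visible in the diff (num_embeddings, embedding_dim, device, bitwidth, groupsize, weight, scales); the exact parameter order, defaults, and import path are not shown in the hunk and are assumed here:

import torch

from qops import QuantizedEmbedding  # import path assumed

num_embeddings, embedding_dim, bitwidth = 1000, 256, 8

# Deferred initialization: buffers start empty and are filled later,
# e.g. by load_state_dict().
emb = QuantizedEmbedding(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    device="cpu",
    bitwidth=bitwidth,
    groupsize=None,  # None/0 means one group per row (groupsize = embedding_dim)
)

# Eager initialization: pass already-quantized weight and per-row scales,
# which is what the updated EmbeddingOnlyInt8QuantHandler now does.
weight = torch.zeros(
    (num_embeddings, (embedding_dim * bitwidth) // 8), dtype=torch.int8
)
scales = torch.ones(num_embeddings, dtype=torch.float16)
emb = QuantizedEmbedding(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
    device="cpu",
    bitwidth=bitwidth,
    groupsize=None,
    weight=weight,
    scales=scales,
)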

quantize.py

Lines changed: 33 additions & 50 deletions
@@ -438,7 +438,7 @@ def quantize(self, module):
                 ),
             )
         else:
-            self.quantize(module)
+            self.quantize(child)
 
         return module
 
@@ -450,31 +450,6 @@ def quantized_model(self) -> nn.Module:
 ##### embedding table quantization ######
 
 
-def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, device, bitwidth: int, groupsize: Optional[int]
-):
-    for name, child in module.named_children():
-        # print(f"name: {name}")
-        if isinstance(child, nn.Embedding):
-            # print(f"{name, child}")
-            # print(f"weights size: {child.weight.size()}")
-            setattr(
-                module,
-                name,
-                QuantizedEmbedding(
-                    device=device,
-                    num_embeddings=child.weight.shape[0],
-                    embedding_dim=child.weight.shape[1],
-                    bitwidth=bitwidth,
-                    groupsize=groupsize,
-                ),
-            )
-        else:
-            replace_embedding_weight_only_grouped_int8_per_channel(
-                child, device, bitwidth, groupsize
-            )
-
-
 class EmbeddingOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
@@ -492,9 +467,11 @@ def __init__(
         self.bitwidth = bitwidth
 
     @torch.no_grad()
-    def create_quantized_state_dict(self) -> Dict:
-        cur_state_dict = state_dict_device(self.model_.state_dict())
-        dict_device = "cpu"  # self.device
+    def quantize(self, module):
+        # cur_state_dict = state_dict_device(self.model_.state_dict())
+        # dict_device = "cpu"  # self.device
+
+        device = self.device
 
         if self.bitwidth == 4:
             range_min = -8
@@ -505,22 +482,23 @@ def create_quantized_state_dict(self) -> Dict:
         else:
             raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
 
-        for fqn, mod in self.model_.named_modules():
-            if isinstance(mod, nn.Embedding):
+        for name, child in module.named_children():
+            # print(f"name: {name}")
+            if isinstance(child, nn.Embedding):
                 # print(f"Embedding identified: {fqn, mod}")
-                # print(f"weights size: {mod.weight.size()}")
+                # print(f"weights size: {child.weight.size()}")
                 # print(f"quantize {fqn}...")
 
                 # print(
                 #     f"quantize {fqn, mod} with groupsize {self.groupsize}, bitwidth {self.bitwidth}"
                 # )
                 weight, scales, _ = dynamically_quantize_per_channel(
-                    mod.weight.float(),
+                    child.weight.float(),
                     range_min,
                     range_max,
                     torch.int8,
                     self.groupsize,
-                    scales_dtype=mod.weight.dtype,
+                    scales_dtype=child.weight.dtype,
                 )
 
                 if self.bitwidth == 4:
@@ -536,26 +514,31 @@ def create_quantized_state_dict(self) -> Dict:
                     weight_packed = weight_even + weight_odd
                     weight = weight_packed
 
-                weight = weight.to(device=dict_device)
-                scales = scales.to(device=dict_device)
-                # Update state dict
-                cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes groupsize=rowsize unidimensional
-                cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
+                weight = weight
+                scales = scales.squeeze(dim=-1)
 
-        return cur_state_dict
+                # print(f"{name, child}")
+                # print(f"weights size: {child.weight.size()}")
+                setattr(
+                    module,
+                    name,
+                    QuantizedEmbedding(
+                        num_embeddings=child.weight.shape[0],
+                        embedding_dim=child.weight.shape[1],
+                        device=self.device,
+                        bitwidth=self.bitwidth,
+                        groupsize=self.groupsize,
+                        weight=weight,
+                        scales=scales,
+                    ),
+                )
+            else:
+                self.quantize(child)
 
-    def convert_for_runtime(self) -> nn.Module:
-        replace_embedding_weight_only_grouped_int8_per_channel(
-            self.model_, self.device, self.bitwidth, self.groupsize
-        )
-        return self.model_
+        return module
 
     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.model_.load_state_dict(model_updated_state_dict)
-        return self.model_
+        return self.quantize(self.model_)
 
 
 #########################################################################
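After this refactor, quantized_model() is simply the recursive quantize() pass shown above. A usage sketch, assuming the handler's constructor takes the model plus the device, bitwidth, and groupsize attributes the diff references (self.model_, self.device, self.bitwidth, self.groupsize); the exact signature and import path are not part of this hunk and are assumed:

import torch.nn as nn

from quantize import EmbeddingOnlyInt8QuantHandler  # import path assumed

# Any module tree containing nn.Embedding children works as input.
model = nn.Sequential(nn.Embedding(1000, 256), nn.Linear(256, 256))

# Constructor arguments are assumed from the attributes used in the diff.
handler = EmbeddingOnlyInt8QuantHandler(model, device="cpu", bitwidth=8, groupsize=0)

# Every nn.Embedding is dynamically quantized per channel and swapped
# for a QuantizedEmbedding holding the int8 weight and scales buffers.
model = handler.quantized_model()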
