@@ -438,7 +438,7 @@ def quantize(self, module):
438
438
),
439
439
)
440
440
else :
441
- self .quantize (module )
441
+ self .quantize (child )
442
442
443
443
return module
444
444
@@ -450,31 +450,6 @@ def quantized_model(self) -> nn.Module:
450
450
##### embedding table quantization ######
451
451
452
452
453
- def replace_embedding_weight_only_grouped_int8_per_channel (
454
- module , device , bitwidth : int , groupsize : Optional [int ]
455
- ):
456
- for name , child in module .named_children ():
457
- # print(f"name: {name}")
458
- if isinstance (child , nn .Embedding ):
459
- # print(f"{name, child}")
460
- # print(f"weights size: {child.weight.size()}")
461
- setattr (
462
- module ,
463
- name ,
464
- QuantizedEmbedding (
465
- device = device ,
466
- num_embeddings = child .weight .shape [0 ],
467
- embedding_dim = child .weight .shape [1 ],
468
- bitwidth = bitwidth ,
469
- groupsize = groupsize ,
470
- ),
471
- )
472
- else :
473
- replace_embedding_weight_only_grouped_int8_per_channel (
474
- child , device , bitwidth , groupsize
475
- )
476
-
477
-
478
453
class EmbeddingOnlyInt8QuantHandler (QuantHandler ):
479
454
def __init__ (
480
455
self ,
@@ -492,9 +467,11 @@ def __init__(
492
467
self .bitwidth = bitwidth
493
468
494
469
@torch .no_grad ()
495
- def create_quantized_state_dict (self ) -> Dict :
496
- cur_state_dict = state_dict_device (self .model_ .state_dict ())
497
- dict_device = "cpu" # self.device
470
+ def quantize (self , module ):
471
+ # cur_state_dict = state_dict_device(self.model_.state_dict())
472
+ # dict_device = "cpu" # self.device
473
+
474
+ device = self .device
498
475
499
476
if self .bitwidth == 4 :
500
477
range_min = - 8
@@ -505,22 +482,23 @@ def create_quantized_state_dict(self) -> Dict:
505
482
else :
506
483
raise ValueError (f"Unsupported bitwidth { self .bitwidth } " )
507
484
508
- for fqn , mod in self .model_ .named_modules ():
509
- if isinstance (mod , nn .Embedding ):
485
+ for name , child in module .named_children ():
486
+ # print(f"name: {name}")
487
+ if isinstance (child , nn .Embedding ):
510
488
# print(f"Embedding identified: {fqn, mod}")
511
- # print(f"weights size: {mod .weight.size()}")
489
+ # print(f"weights size: {child .weight.size()}")
512
490
# print(f"quantize {fqn}...")
513
491
514
492
# print(
515
493
# f"quantize {fqn, mod} with groupsize {self.groupsize}, bitwidth {self.bitwidth}"
516
494
# )
517
495
weight , scales , _ = dynamically_quantize_per_channel (
518
- mod .weight .float (),
496
+ child .weight .float (),
519
497
range_min ,
520
498
range_max ,
521
499
torch .int8 ,
522
500
self .groupsize ,
523
- scales_dtype = mod .weight .dtype ,
501
+ scales_dtype = child .weight .dtype ,
524
502
)
525
503
526
504
if self .bitwidth == 4 :
@@ -536,26 +514,31 @@ def create_quantized_state_dict(self) -> Dict:
536
514
weight_packed = weight_even + weight_odd
537
515
weight = weight_packed
538
516
539
- weight = weight .to (device = dict_device )
540
- scales = scales .to (device = dict_device )
541
- # Update state dict
542
- cur_state_dict [f"{ fqn } .weight" ] = weight
543
- # squeeze makes groupsize=rowsize unidimensional
544
- cur_state_dict [f"{ fqn } .scales" ] = scales .squeeze (dim = - 1 )
517
+ weight = weight
518
+ scales = scales .squeeze (dim = - 1 )
545
519
546
- return cur_state_dict
520
+ # print(f"{name, child}")
521
+ # print(f"weights size: {child.weight.size()}")
522
+ setattr (
523
+ module ,
524
+ name ,
525
+ QuantizedEmbedding (
526
+ num_embeddings = child .weight .shape [0 ],
527
+ embedding_dim = child .weight .shape [1 ],
528
+ device = self .device ,
529
+ bitwidth = self .bitwidth ,
530
+ groupsize = self .groupsize ,
531
+ weight = weight ,
532
+ scales = scales ,
533
+ ),
534
+ )
535
+ else :
536
+ self .quantize (child )
547
537
548
- def convert_for_runtime (self ) -> nn .Module :
549
- replace_embedding_weight_only_grouped_int8_per_channel (
550
- self .model_ , self .device , self .bitwidth , self .groupsize
551
- )
552
- return self .model_
538
+ return module
553
539
554
540
def quantized_model (self ) -> nn .Module :
555
- model_updated_state_dict = self .create_quantized_state_dict ()
556
- self .convert_for_runtime ()
557
- self .model_ .load_state_dict (model_updated_state_dict )
558
- return self .model_
541
+ return self .quantize (self .model_ )
559
542
560
543
561
544
#########################################################################
0 commit comments