@@ -494,6 +494,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
494
494
group_size = group_size ,
495
495
dtype = child .weight .dtype ,
496
496
packed = packed ,
497
+ bitwidth = bitwidth ,
497
498
),
498
499
)
499
500
else :
@@ -519,14 +520,17 @@ def __init__(
519
520
self .group_size = group_size
520
521
self .bitwidth = bitwidth
521
522
self .packed = packed
522
- if (bitwidth != 4 ) and packed :
523
- raise RuntimeError ("pack only works with bitsize 4" )
523
+ if (bitwidth not in [ 2 , 4 ] ) and packed :
524
+ raise RuntimeError ("pack only works with bitsize 2, 4" )
524
525
525
526
@torch .no_grad ()
526
527
def create_quantized_state_dict (self , packed = False ) -> Dict :
527
528
cur_state_dict = self .mod .state_dict ()
528
529
529
- if self .bitwidth == 4 :
530
+ if self .bitwidth == 2 :
531
+ range_min = - 2
532
+ range_max = 1
533
+ elif self .bitwidth == 4 :
530
534
range_min = - 8
531
535
range_max = 7
532
536
elif self .bitwidth == 8 :
@@ -555,17 +559,30 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
555
559
)
556
560
557
561
if packed :
558
- if weight .shape [- 1 ] % 2 != 0 :
559
- raise RuntimeError ("automatic padding not implemented yet" )
560
-
561
- weight_range_shifted = weight .add (8 ).view (torch .uint8 )
562
- weight_view = weight_range_shifted .view (
563
- weight .shape [0 ], weight .shape [1 ] // 2 , 2
564
- )
565
- weight_even = weight_view [:, :, 0 ] * 16 # left shift 4
566
- weight_odd = weight_view [:, :, 1 ]
567
- weight_packed = weight_even + weight_odd
568
- weight = weight_packed
562
+ if self .bitwidth == 2 :
563
+ if weight .shape [- 1 ] % 4 != 0 :
564
+ raise RuntimeError ("automatic padding not implemented yet" )
565
+ weight_range_shifted = weight .add (2 ).view (torch .uint8 )
566
+ weight_view = weight_range_shifted .view (
567
+ weight .shape [0 ], weight .shape [1 ] // 4 , 4
568
+ )
569
+ weight_0 = weight_view [:, :, 0 ]
570
+ weight_1 = weight_view [:, :, 1 ] << 2
571
+ weight_2 = weight_view [:, :, 2 ] << 4
572
+ weight_3 = weight_view [:, :, 3 ] << 6
573
+ weight_packed = weight_0 + weight_1 + weight_2 + weight_3
574
+ weight = weight_packed
575
+ elif self .bitwidth == 4 :
576
+ if weight .shape [- 1 ] % 2 != 0 :
577
+ raise RuntimeError ("automatic padding not implemented yet" )
578
+ weight_range_shifted = weight .add (8 ).view (torch .uint8 )
579
+ weight_view = weight_range_shifted .view (
580
+ weight .shape [0 ], weight .shape [1 ] // 2 , 2
581
+ )
582
+ weight_even = weight_view [:, :, 0 ] * 16 # left shift 4
583
+ weight_odd = weight_view [:, :, 1 ]
584
+ weight_packed = weight_even + weight_odd
585
+ weight = weight_packed
569
586
570
587
weight = weight .to (device = self .device )
571
588
scales = scales .to (device = self .device )
@@ -598,13 +615,15 @@ def __init__(
598
615
group_size : Optional [int ] = None ,
599
616
dtype = torch .half ,
600
617
packed = False ,
618
+ bitwidth : int = 8 ,
601
619
) -> None :
602
620
super ().__init__ ()
603
621
if group_size is None or group_size == 0 :
604
622
group_size = embedding_dim
605
623
self .group_size = group_size
606
624
self .dtype = dtype
607
625
self .packed = packed
626
+ self .bitwidth = bitwidth
608
627
if not packed :
609
628
self .register_buffer (
610
629
"weight" ,
@@ -613,12 +632,25 @@ def __init__(
613
632
),
614
633
)
615
634
else : # packed
616
- self .register_buffer (
617
- "weight" ,
618
- torch .empty (
619
- (vocab_size , embedding_dim // 2 ), dtype = torch .uint8 , device = device
620
- ),
621
- )
635
+ if bitwidth == 2 :
636
+ self .register_buffer (
637
+ "weight" ,
638
+ torch .empty (
639
+ (vocab_size , embedding_dim // 4 ),
640
+ dtype = torch .uint8 ,
641
+ device = device ,
642
+ ),
643
+ )
644
+ elif bitwidth == 4 :
645
+ self .register_buffer (
646
+ "weight" ,
647
+ torch .empty (
648
+ (vocab_size , embedding_dim // 2 ),
649
+ dtype = torch .uint8 ,
650
+ device = device ,
651
+ ),
652
+ )
653
+
622
654
groups_per_row = (embedding_dim + group_size - 1 ) // group_size
623
655
if groups_per_row > 1 :
624
656
self .register_buffer (
@@ -638,7 +670,14 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
638
670
return torch .ops .quantized_decomposed .embedding_byte .dtype (
639
671
self .weight , self .scales , None , 0 , 0 , indices , dtype = self .dtype
640
672
)
641
- else : # 4bit packed
673
+ else : # packed
674
+ if self .bitwidth == 2 :
675
+ return torch .ops .quantized_decomposed .embedding_2bit .dtype (
676
+ self .weight , self .scales , None , 0 , 0 , indices , dtype = self .dtype
677
+ )
678
+
679
+ # Remaining case (always return to make pyre happy)
680
+ assert self .bitwidth == 4
642
681
return torch .ops .quantized_decomposed .embedding_4bit .dtype (
643
682
self .weight , self .scales , None , 0 , 0 , indices , dtype = self .dtype
644
683
)
@@ -658,7 +697,7 @@ def get_quant_embedding_transform(args):
658
697
model ,
659
698
bitwidth = bitwidth ,
660
699
group_size = group_size ,
661
- packed = (bitwidth == 4 ),
700
+ packed = (bitwidth in [ 2 , 4 ] ),
662
701
).quantized_model ()
663
702
664
703
0 commit comments