import torch
import torch.nn as nn
import torch.nn.functional as F
-from build.utils import find_multiple, get_precision
+from build.utils import find_multiple, get_precision, use_et_backend


#########################################################################
@@ -92,30 +92,6 @@ def quantized_model(self) -> nn.Module:
        return self.quantizer.quantize(self.model_)


-#########################################################################
-### QuantHandler API definition ###
-### (unify with torchao in future) ###
-
-
-class QuantHandler:
-    def __init__(self, model: nn.Module, device="cpu", tokenizer=None):
-        self.model_ = model
-        self.device = device
-        self.tokenizer = tokenizer
-
-    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
-        pass
-
-    def convert_for_runtime(self) -> nn.Module:
-        pass
-
-    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.model_.load_state_dict(model_updated_state_dict)
-        return self.model_
-
-
#########################################################################
### wrapper for setting precision as a QuantHandler ###

@@ -521,7 +497,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:


def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False
+    module, device, bitwidth: int, groupsize: Optional[int]
):
    for name, child in module.named_children():
        # print(f"name: {name}")
@@ -535,13 +511,13 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                    device=device,
                    vocab_size=child.weight.shape[0],
                    embedding_dim=child.weight.shape[1],
+                    bitwidth=bitwidth,
                    groupsize=groupsize,
-                    packed=packed,
                ),
            )
        else:
            replace_embedding_weight_only_grouped_int8_per_channel(
-                child, device, bitwidth, groupsize, packed
+                child, device, bitwidth, groupsize
            )

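For orientation, a minimal standalone sketch (not from the repo) of the module-replacement pattern this function uses: recursively walk named_children() and setattr() a substitute for every nn.Embedding. The swap_embeddings helper and the plain float "replacement" are illustrative only.

    import torch.nn as nn

    def swap_embeddings(module: nn.Module, make_replacement):
        # Recurse over the module tree, replacing embeddings in place.
        for name, child in module.named_children():
            if isinstance(child, nn.Embedding):
                setattr(module, name, make_replacement(child))
            else:
                swap_embeddings(child, make_replacement)

    # Toy usage: swap every embedding for a fresh one of the same shape.
    model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))
    swap_embeddings(model, lambda e: nn.Embedding(e.num_embeddings, e.embedding_dim))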
@@ -554,19 +530,15 @@ def __init__(
        *,
        bitwidth: int = 8,
        groupsize: Optional[int] = None,
-        packed=True,
+        packed=True,  # we always pack bitwidth 4 now
    ):
-        # when quantization dictionary comes from JSON, packed is a string
-        if isinstance(packed, str):
-            packed = packed.lower() != "false"
        self.model_ = model
        self.device = device
        self.groupsize = groupsize
        self.bitwidth = bitwidth
-        self.packed = packed

    @torch.no_grad()
-    def create_quantized_state_dict(self, packed=False) -> Dict:
+    def create_quantized_state_dict(self) -> Dict:
        cur_state_dict = self.model_.state_dict()

        if self.bitwidth == 4:
@@ -596,7 +568,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                    scales_dtype=mod.weight.dtype,
                )

-                if packed:
+                if self.bitwidth == 4:
                    if weight.shape[-1] % 2 != 0:
                        raise RuntimeError("automatic padding not implemented yet")

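For context, a standalone sketch of group-wise symmetric quantization, the kind of transform this state-dict pass applies per channel before the optional 4-bit packing. It is illustrative only and is not the repo's dynamically_quantize_per_channel; quantize_groupwise is a hypothetical helper.

    import torch

    def quantize_groupwise(w: torch.Tensor, groupsize: int, bitwidth: int):
        # One scale per (row, group); symmetric quantization around zero.
        qmax = 2 ** (bitwidth - 1) - 1                 # 7 for 4-bit, 127 for 8-bit
        groups = w.reshape(w.shape[0], -1, groupsize)
        scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
        q = torch.round(groups / scales).clamp(-qmax - 1, qmax).to(torch.int8)
        return q.reshape(w.shape), scales.squeeze(-1)

    w = torch.randn(4, 8)                              # toy embedding table
    q, scales = quantize_groupwise(w, groupsize=4, bitwidth=4)
    w_hat = (q.reshape(4, -1, 4) * scales.unsqueeze(-1)).reshape(4, 8)
    print((w - w_hat).abs().max())                     # small reconstruction error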
@@ -620,12 +592,12 @@ def create_quantized_state_dict(self, packed=False) -> Dict:

    def convert_for_runtime(self) -> nn.Module:
        replace_embedding_weight_only_grouped_int8_per_channel(
-            self.model_, self.device, self.bitwidth, self.groupsize, self.packed
+            self.model_, self.device, self.bitwidth, self.groupsize
        )
        return self.model_

    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
+        model_updated_state_dict = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.model_.load_state_dict(model_updated_state_dict)
        return self.model_
@@ -637,30 +609,42 @@ def __init__(
        device,
        vocab_size: int,
        embedding_dim: int,
+        bitwidth: int,
        groupsize: Optional[int] = None,
+        *,
        dtype=torch.half,
-        packed=False,
    ) -> None:
        super().__init__()
        if groupsize is None or groupsize == 0:
            groupsize = embedding_dim
        self.groupsize = groupsize
        self.dtype = dtype
-        self.packed = packed
-        if not packed:
+        self.bitwidth = bitwidth
+
+        if use_et_backend():
+            self.forward = self.et_forward
+        else:
+            self.forward = self.aoti_forward
+
+        if bitwidth == 8:
            self.register_buffer(
                "weight",
                torch.empty(
                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
                ),
            )
-        else:  # packed
+        elif bitwidth == 4:  # packed
            self.register_buffer(
                "weight",
                torch.empty(
                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
                ),
            )
+        else:
+            raise RuntimeError(
+                f"Quantized embedding does not support bitwidth={bitwidth}"
+            )
+
        groups_per_row = (embedding_dim + groupsize - 1) // groupsize
        if groups_per_row > 1:
            self.register_buffer(
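For orientation, a standalone toy (not the repo's class) showing the dispatch pattern adopted here: choose the forward implementation once in __init__ instead of branching on every call. The real code branches on use_et_backend() from build.utils; the toy takes a plain flag, and the bodies are stand-ins.

    import torch
    import torch.nn as nn

    class DispatchToy(nn.Module):
        def __init__(self, use_et: bool):
            super().__init__()
            # Bind the chosen bound method; later calls skip the backend branch.
            self.forward = self.et_forward if use_et else self.aoti_forward

        def et_forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + 1   # stand-in for the ExecuTorch kernel path

        def aoti_forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + 2   # stand-in for the eager / AOT Inductor path

    print(DispatchToy(use_et=False)(torch.zeros(1)))   # tensor([2.])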
@@ -675,16 +659,22 @@ def __init__(
            )

    @torch.no_grad()
-    def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if False:  # Used for Executorch
-            return torch.ops.llama_quantized.embedding_byte.dtype(
+    def et_forward(self, indices: torch.Tensor) -> torch.Tensor:
+        if self.bitwidth == 8:
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
            )

+    @torch.no_grad()
+    def aoti_forward(self, indices: torch.Tensor) -> torch.Tensor:
        # result_weights = self.weight.index_select(0, indices.view(-1))
        # result_scales = self.scales.index_select(0, indices.view(-1))

-        if self.packed:
+        if self.bitwidth == 4:
            weight_even = self.weight.div(16, rounding_mode="trunc")
            weight_odd = self.weight.remainder(16)
            weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
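To make the 4-bit path concrete, a standalone roundtrip (not from the repo) of the nibble packing that the div(16) / remainder(16) unpacking above assumes: two 4-bit values share one uint8 byte, even column in the high nibble. The offset by 8, used to store signed int4 as unsigned nibbles, is an assumption of this sketch.

    import torch

    q = torch.tensor([[-8, 7, 0, 3]], dtype=torch.int8)   # toy int4 values in [-8, 7]
    u = (q + 8).to(torch.uint8)                            # shift into unsigned nibbles [0, 15]
    packed = u[:, 0::2] * 16 + u[:, 1::2]                  # even column -> high nibble, odd -> low

    even = packed.div(16, rounding_mode="trunc")           # recover even columns
    odd = packed.remainder(16)                             # recover odd columns
    unpacked = torch.stack((even, odd), dim=-1).reshape(q.shape).to(torch.int8) - 8
    assert torch.equal(unpacked, q)                        # lossless roundtrip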