15
15
import torch
16
16
import torch .nn as nn
17
17
import torch .nn .functional as F
18
- from build .utils import find_multiple , get_precision , name_to_dtype , use_et_backend
18
+ from build .utils import find_multiple , get_precision , name_to_dtype , use_et_backend , state_dict_device
19
19
20
20
21
21
#########################################################################
@@ -63,7 +63,7 @@ def convert_for_runtime(self) -> nn.Module:
63
63
pass
64
64
65
65
def quantized_model (self ) -> nn .Module :
66
- model_updated_state_dict = self .create_quantized_state_dict ()
66
+ model_updated_state_dict = state_dict_device ( self .create_quantized_state_dict () )
67
67
self .convert_for_runtime ()
68
68
self .model_ .load_state_dict (model_updated_state_dict )
69
69
return self .model_
@@ -406,8 +406,9 @@ def __init__(
406
406
407
407
@torch .no_grad ()
408
408
def create_quantized_state_dict (self ) -> Dict :
409
- cur_state_dict = self .model_ .state_dict ()
410
-
409
+ cur_state_dict = state_dict_device (self .model_ .state_dict ())
410
+ dict_device = "cpu" # self.device
411
+
411
412
if self .bitwidth == 4 :
412
413
range_min = - 8
413
414
range_max = 7
@@ -446,8 +447,8 @@ def create_quantized_state_dict(self) -> Dict:
446
447
scales_dtype = mod .weight .dtype ,
447
448
)
448
449
449
- weight = weight .to (device = self . device )
450
- scales = scales .to (device = self . device )
450
+ weight = weight .to (device = dict_device )
451
+ scales = scales .to (device = dict_device )
451
452
cur_state_dict [f"{ fqn } .weight" ] = weight
452
453
# squeeze makes groupsize=rowsize unidimensional
453
454
cur_state_dict [f"{ fqn } .scales" ] = scales .squeeze (dim = - 1 )
@@ -553,7 +554,8 @@ def __init__(
553
554
554
555
@torch .no_grad ()
555
556
def create_quantized_state_dict (self ) -> Dict :
556
- cur_state_dict = self .model_ .state_dict ()
557
+ cur_state_dict = state_dict_device (self .model_ .state_dict ())
558
+ dict_device = "cpu" # self.device
557
559
558
560
if self .bitwidth == 4 :
559
561
range_min = - 8
@@ -595,8 +597,8 @@ def create_quantized_state_dict(self) -> Dict:
595
597
weight_packed = weight_even + weight_odd
596
598
weight = weight_packed
597
599
598
- weight = weight .to (device = self . device )
599
- scales = scales .to (device = self . device )
600
+ weight = weight .to (device = dict_device )
601
+ scales = scales .to (device = dict_device )
600
602
# Update state dict
601
603
cur_state_dict [f"{ fqn } .weight" ] = weight
602
604
# squeeze makes groupsize=rowsize unidimensional
@@ -822,9 +824,21 @@ def __init__(
822
824
assert groupsize in [32 , 64 , 128 , 256 ]
823
825
assert inner_k_tiles in [2 , 4 , 8 ]
824
826
827
+
828
+ # @torch.no_grad()
829
+ # def p(self):
830
+ # cur_state_dict = state_dict_device(self.model_.state_dict())
831
+ # dict_device = "cpu" # self.device
832
+ #
833
+ # for fqn, mod in self.model_.named_modules():
834
+ # if hasattr(mod, "weight"):
835
+ # print(f"device={str(mod.weight.data.device)}")
836
+
825
837
@torch .no_grad ()
826
838
def create_quantized_state_dict (self ):
827
- cur_state_dict = self .model_ .state_dict ()
839
+ cur_state_dict = state_dict_device (self .model_ .state_dict ())
840
+ dict_device = "cpu" # self.device
841
+
828
842
for fqn , mod in self .model_ .named_modules ():
829
843
if isinstance (mod , torch .nn .Linear ):
830
844
assert not mod .bias
@@ -856,8 +870,8 @@ def create_quantized_state_dict(self):
856
870
weight .to (torch .float ), self .groupsize , self .inner_k_tiles
857
871
)
858
872
)
859
- weight_int4pack = weight_int4pack .to (device = self . device )
860
- scales_and_zeros = scales_and_zeros .to (device = self . device )
873
+ weight_int4pack = weight_int4pack .to (device = dict_device )
874
+ scales_and_zeros = scales_and_zeros .to (device = dict_device )
861
875
cur_state_dict [f"{ fqn } .weight" ] = weight_int4pack
862
876
cur_state_dict [f"{ fqn } .scales_and_zeros" ] = scales_and_zeros
863
877
@@ -877,6 +891,7 @@ def quantized_model(self) -> nn.Module:
877
891
model_updated_state_dict = self .create_quantized_state_dict ()
878
892
self .convert_for_runtime ()
879
893
self .model_ .load_state_dict (model_updated_state_dict )
894
+ # self.p()
880
895
return self .model_
881
896
882
897
0 commit comments