22
22
23
23
24
24
try :
25
- # pyre-ignore
25
+ # pyre-ignore[21]: Undefined import.
26
26
from fairseq2 .nn .embedding import (
27
27
Embedding as fsEmbedding ,
28
28
StandardEmbedding as fsStandardEmbedding ,
29
29
)
30
30
31
- # pyre-ignore
31
+ # pyre-ignore[21]: Undefined import.
32
32
from fairseq2 .nn .projection import Linear as fsLinear
33
33
except :
34
34
print ("Could not import fairseq2 modules." )
@@ -645,14 +645,6 @@ def create_quantized_state_dict(self) -> Dict:
645
645
646
646
# print(f"initial weight shape {mod.weight.shape}")
647
647
input_weight = mod .weight .float ()
648
- input_weight_shape_1 = input_weight .shape [1 ]
649
- if (self .group_size is not None ) and (
650
- input_weight_shape_1 % self .group_size != 0
651
- ):
652
- padding = self .group_size - (
653
- input_weight_shape_1 % self .group_size
654
- )
655
- input_weight = F .pad (input_weight , (0 , padding ))
656
648
657
649
# print(f"expanded weight shape {input_weight.shape}")
658
650
weight , scales , _ = dynamically_quantize_per_channel (
@@ -663,9 +655,8 @@ def create_quantized_state_dict(self) -> Dict:
663
655
self .group_size ,
664
656
scales_dtype = mod .weight .dtype ,
665
657
)
666
- unpadded_weight = weight [:, :input_weight_shape_1 ]
667
658
668
- cur_state_dict [f"{ fqn } .weight" ] = unpadded_weight
659
+ cur_state_dict [f"{ fqn } .weight" ] = weight
669
660
# squeeze makes groupsize=rowsize unidimensional
670
661
cur_state_dict [f"{ fqn } .scales" ] = scales .squeeze (dim = - 1 )
671
662
0 commit comments