@@ -98,21 +98,38 @@ def quantize( # noqa C901
98
98
matches = re .findall (pattern , qmode )
99
99
assert len (matches ) == 1 , f"Expected 1 match for pattern but got { len (matches )} "
100
100
bitwidth = int (matches [0 ][0 ])
101
- _load_torchao_aten_lib (libname = "libtorchao_ops_aten" )
102
- from torchao .experimental .quant_api import Int8DynActIntxWeightLinearQuantizer
101
+ # _load_torchao_aten_lib(libname="libtorchao_ops_aten")
102
+ # from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
103
+ from torchao .experimental .quant_api import int8_dynamic_activation_intx_weight , Int8DynActIntxWeightLinearQuantizer
104
+ from torchao .quantization .quant_api import quantize_
105
+ from torchao .utils import unwrap_tensor_subclass
106
+ from torchao .quantization .granularity import PerRow , PerGroup
103
107
104
108
with torch .no_grad ():
105
- model = Int8DynActIntxWeightLinearQuantizer (
106
- device = "cpu" ,
107
- precision = torch .float32 ,
108
- groupsize = group_size ,
109
- bitwidth = bitwidth ,
110
- has_weight_zeros = False ,
111
- ).quantize (model )
112
-
109
+ # model = Int8DynActIntxWeightLinearQuantizer(
110
+ # device="cpu",
111
+ # precision=torch.float32,
112
+ # groupsize=group_size,
113
+ # bitwidth=bitwidth,
114
+ # has_weight_zeros=False,
115
+ # ).quantize(model)
116
+
117
+ quantize_ (model ,
118
+ int8_dynamic_activation_intx_weight (
119
+ # group_size=group_size,
120
+ # nbit=bitwidth,
121
+ # has_weight_zeros=False,
122
+ weight_dtype = getattr (torch , f"int{ bitwidth } " ),
123
+ granularity = PerRow () if group_size == 0 else PerGroup (group_size ),
124
+ has_weight_zeros = False ,
125
+ ),
126
+ )
127
+ model = unwrap_tensor_subclass (model )
113
128
if verbose :
114
129
print ("quantized model:" , model )
115
130
return model
131
+
132
+ return model
116
133
elif qmode == "8da4w" :
117
134
# Check for required args
118
135
if group_size is None :
@@ -752,7 +769,7 @@ def get_quant_embedding_transform(args):
752
769
bitwidth , group_size = args .embedding_quantize .split (":" )[1 ].split ("," )
753
770
group_size = int (group_size )
754
771
bitwidth = int (bitwidth )
755
- _load_torchao_aten_lib (libname = "libtorchao_ops_aten" )
772
+ # _load_torchao_aten_lib(libname="libtorchao_ops_aten")
756
773
from torchao .experimental .quant_api import IntxWeightEmbeddingQuantizer
757
774
758
775
def _torchao_embedding_quantizer (model ):
0 commit comments