@@ -122,21 +122,7 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsiz
             input.dtype
         )  # cast back to input.dtype
     else:
-        # copied from https://github.com/pytorch/pytorch/blob/0052943bee624c06d8c36a371efdf7b56972ad9e/torch/_meta_registrations.py#L3308
-        def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
-            torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
-            torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
-            torch._check(
-                x.dtype in [torch.float32, torch.float16, torch.bfloat16],
-                lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
-            )
-            torch._check(
-                w.dtype is torch.int32,
-                lambda: f"expected w to be int32, got {w.dtype}",
-            )
-            return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
-
-        c = meta__weight_int4pack_mm(
+        c = torch.ops.aten._weight_int4pack_mm_for_cpu(
             input,
             weight_int4pack,
             groupsize,
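Note (not part of the diff): this hunk drops the hand-copied meta shape function in favor of the real CPU kernel. A minimal standalone sketch of how the two ops compose, assuming a PyTorch build that ships `torch.ops.aten._convert_weight_to_int4pack_for_cpu` and `torch.ops.aten._weight_int4pack_mm_for_cpu` (PyTorch >= 2.5), and assuming the `[k // groupsize, n, 2]` scales-and-zeros layout of the CUDA op carries over to the CPU variant:

```python
import torch

n, k, groupsize, inner_k_tiles = 32, 64, 32, 2

# Hypothetical int4 weight values held in an int32 tensor, like the PR's `q`.
q = torch.randint(0, 16, (n, k), dtype=torch.int32)

# Pack for the CPU kernel; the CPU packer takes the unpacked int32 tensor.
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
    q, inner_k_tiles
)

# Placeholder per-group scales (index 0) and zero points (index 1); the
# layout here is an assumption, not taken from this PR.
scales_and_zeros = torch.ones(k // groupsize, n, 2, dtype=torch.bfloat16)

x = torch.randn(4, k, dtype=torch.bfloat16)
c = torch.ops.aten._weight_int4pack_mm_for_cpu(
    x, weight_int4pack, groupsize, scales_and_zeros
)
print(c.shape)  # expected: torch.Size([4, 32])
```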
@@ -626,27 +612,10 @@ def load_model_and_state_dict(
         q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
 
         if torch.device(device).type == "cpu":
-            # Copied from https://github.com/pytorch/pytorch/blob/0052943bee624c06d8c36a371efdf7b56972ad9e/torch/_meta_registrations.py#L3273
-            def meta__convert_weight_to_int4pack(w, inner_k_tiles):
-                torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
-                torch._check(
-                    w.dtype is torch.uint8,
-                    lambda: f"expected w to be uint8, got {w.dtype}",
+            weight_int4pack = (
+                torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                    q, inner_k_tiles
                 )
-                n = w.size(0)
-                k = w.size(1) * 2  # w is [n][k / 2] uint8
-                return w.new_empty(
-                    (
-                        n // 8,
-                        k // (inner_k_tiles * 16),
-                        32,
-                        inner_k_tiles // 2,
-                    ),
-                    dtype=torch.int32,
-                )
-
-            weight_int4pack = meta__convert_weight_to_int4pack(
-                q_uint8, inner_k_tiles
             )
         else:
             weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
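Note (not part of the diff): the CPU branch now passes the unpacked int32 `q` to `_convert_weight_to_int4pack_for_cpu`, while the `q_uint8` nibble packing computed above is kept for the non-CPU branch, whose packer expects uint8 (as the removed meta function's dtype check also shows). For reference, a quick round-trip sketch of that packing, using the packing line verbatim from the diff:

```python
import torch

q = torch.randint(0, 16, (4, 8), dtype=torch.int32)  # int4 values in int32

# Two 4-bit values per byte: even columns land in the high nibble.
q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)

# Unpack and re-interleave the columns to verify the round trip.
high = (q_uint8 >> 4).to(torch.int32)
low = (q_uint8 & 0xF).to(torch.int32)
restored = torch.stack([high, low], dim=-1).reshape(q.shape)
assert torch.equal(restored, q)
```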