@@ -122,7 +122,21 @@ def linear_int4(input, weight_int4pack, scales_and_zeros, out_features, groupsiz
             input.dtype
         )  # cast back to input.dtype
     else:
-        c = torch.ops.aten._weight_int4pack_mm(
+        # copied from https://github.com/pytorch/pytorch/blob/0052943bee624c06d8c36a371efdf7b56972ad9e/torch/_meta_registrations.py#L3308
+        def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
+            torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
+            torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
+            torch._check(
+                x.dtype in [torch.float32, torch.float16, torch.bfloat16],
+                lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
+            )
+            torch._check(
+                w.dtype is torch.int32,
+                lambda: f"expected w to be int32, got {w.dtype}",
+            )
+            return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
+
+        c = meta__weight_int4pack_mm(
             input,
             weight_int4pack,
             groupsize,
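
Note: the meta function above performs no matmul; it only reproduces the output shape of aten._weight_int4pack_mm. The packed weight stores n/8 along its first dimension, so the result is [m, n] with n = w.size(0) * 8. A minimal standalone sketch of that shape math, using made-up tensor sizes (not taken from this diff):

import torch

# Hypothetical sizes: [m, k] activations against an int4 weight packed
# as [n/8, k/(inner_k_tiles*16), 32, inner_k_tiles/2] with n = k = 4096.
x = torch.randn(4, 4096, dtype=torch.bfloat16)      # [m, k]
w = torch.empty(512, 32, 32, 4, dtype=torch.int32)  # first dim is n/8 = 512
c = x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
print(c.shape)  # torch.Size([4, 4096]) -- shape only, values uninitialized
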
@@ -610,10 +624,29 @@ def load_model_and_state_dict(
             q, s, z = Q4_0.unpack(t)
             scales_and_zeros = pack_scales_and_zeros(s, z)
             q_uint8 = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
-
+
             if torch.device(device).type == "cpu":
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
-                    q_uint8, inner_k_tiles
+                # Copied from https://github.com/pytorch/pytorch/blob/0052943bee624c06d8c36a371efdf7b56972ad9e/torch/_meta_registrations.py#L3273
+                def meta__convert_weight_to_int4pack(w, inner_k_tiles):
+                    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
+                    torch._check(
+                        w.dtype is torch.uint8,
+                        lambda: f"expected w to be uint8, got {w.dtype}",
+                    )
+                    n = w.size(0)
+                    k = w.size(1) * 2  # w is [n][k / 2] uint8
+                    return w.new_empty(
+                        (
+                            n // 8,
+                            k // (inner_k_tiles * 16),
+                            32,
+                            inner_k_tiles // 2,
+                        ),
+                        dtype=torch.int32,
+                    )
+
+                weight_int4pack = meta__convert_weight_to_int4pack(
+                    q_uint8, inner_k_tiles
                 )
             else:
                 weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
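
For reference, the shape math in meta__convert_weight_to_int4pack: an [n, k/2] uint8 tensor (two int4 values per byte) maps to an int32 tensor of shape [n/8, k/(inner_k_tiles*16), 32, inner_k_tiles/2]. A quick standalone sanity check with hypothetical n = k = 4096 and inner_k_tiles = 8 (sizes chosen for illustration, not from this diff):

import math

n, k, inner_k_tiles = 4096, 4096, 8
packed_shape = (n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2)
print(packed_shape)  # (512, 32, 32, 4)

# Both layouts hold the same number of int4 values (byte counts match):
src_bytes = n * (k // 2)                  # [n, k/2] uint8, one byte per pair
dst_bytes = 4 * math.prod(packed_shape)   # int32 elements * 4 bytes each
assert src_bytes == dst_bytes == 8_388_608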