@@ -30,16 +30,42 @@ at::Tensor linear_weight_int4_reference_impl(
   const size_t ndim = original_x_size.size();
   const int64_t out_features = weights_4x2.size(0);
   const at::Tensor x_flattened = x.reshape({-1, original_x_size[ndim - 1]});
-  const at::Tensor packed_weights =
-      at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles);
-  at::Tensor out = at::_weight_int4pack_mm(
-      x_flattened, packed_weights, groupsize, scales_and_zeros);
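+  // weights_4x2 must be pre-packed by at::_convert_weight_to_int4pack_for_cpu.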
+  at::Tensor out = at::_weight_int4pack_mm_for_cpu(
+      x_flattened, weights_4x2, groupsize, scales_and_zeros);
   std::vector<int64_t> out_shape(
       original_x_size.begin(), original_x_size.end());
   out_shape.at(ndim - 1) = out_features;
   return out.reshape(out_shape);
 }
 
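+// Unpacks an (N, K/2) uint8 tensor, where each byte holds two 4-bit values
+// (high nibble first), into an (N, K) int32 tensor.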
+at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) {
+  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
+  weights_shape[1] *= 2;
+
+  at::Tensor weights_unpacked =
+      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kInt));
+
+  const int64_t N = weights_unpacked.size(0);
+  const int64_t K = weights_unpacked.size(1);
+
+  for (int n = 0; n < N; n++) {
+    for (int k = 0; k < K; k += 2) {
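+      // Split the byte: high nibble is the first value, low nibble the second.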
+      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
+      const uint8_t second_val = packed_val & 0x0F;
+      const uint8_t first_val = (packed_val & 0xF0) >> 4;
+
+      weights_unpacked[n][k] = int(first_val);
+      weights_unpacked[n][k + 1] = int(second_val);
+    }
+  }
+
+  return weights_unpacked;
+}
+
 at::Tensor dequantize_and_linear(
     const at::Tensor& x,
     const at::Tensor& weights_4x2,
@@ -91,13 +117,20 @@ void test_reference_linear_int4(
   at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat));
   at::Tensor weights_4x2 =
       at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte));
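+  // Unpack the random packed bytes into plain ints for re-packing below.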
+  at::Tensor weights_int = unpack_weights_4x2(weights_4x2);
 
   const int k_groups = K / group_size;
   at::Tensor scales_and_zeros =
       at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));
 
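+  // Re-pack the weights into the layout expected by at::_weight_int4pack_mm_for_cpu.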
   at::Tensor out = linear_weight_int4_reference_impl(
-      x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
+      x,
+      at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size),
+      group_size,
+      scales_and_zeros,
+      inner_k_tiles);
 
   at::Tensor out_ref = dequantize_and_linear(
       x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);