
Commit 6d17090

pytorchbot authored and SS-JIA committed
[ET-VK][ez] Fix linear weight int4 test due to change in ATen API (#7751)
Pull Request resolved: #7739

## Context

The ATen API for 4-bit quantized linear recently changed, so our test must adapt to the new API. Concretely, the changes were:

* The `_for_cpu` suffix was added to the operator names.
* The `_convert_weight_to_int4pack_for_cpu` operator now expects unpacked 4-bit weights, instead of a packed scheme where two 4-bit values are packed into a single 8-bit value.

ghstack-source-id: 261959346
@exported-using-ghexport

Differential Revision: [D68333687](https://our.internmc.facebook.com/intern/diff/D68333687/)

Co-authored-by: Stephen Jia <[email protected]>
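To make the call-site change concrete, here is a minimal sketch. The wrapper name `reference_int4_mm` is illustrative and not part of the test; the operator calls and tensor names are taken from the diff below, with the pre-change path kept in a comment for contrast.

```cpp
#include <ATen/ATen.h>

// Illustrative wrapper: the reference int4 matmul after the API change.
// The new _for_cpu operator consumes the 4-bit weight tensor directly,
// with no separate prepacking step at this call site.
at::Tensor reference_int4_mm(
    const at::Tensor& x_flattened,
    const at::Tensor& weights_4x2,
    int64_t groupsize,
    const at::Tensor& scales_and_zeros) {
  // Pre-change ATen path (removed by this commit):
  //   const at::Tensor packed_weights =
  //       at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles);
  //   return at::_weight_int4pack_mm(
  //       x_flattened, packed_weights, groupsize, scales_and_zeros);
  return at::_weight_int4pack_mm_for_cpu(
      x_flattened, weights_4x2, groupsize, scales_and_zeros);
}
```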
1 parent 3055a5c commit 6d17090

1 file changed: +32, -5 lines

backends/vulkan/test/op_tests/linear_weight_int4_test.cpp

```diff
@@ -30,16 +30,38 @@ at::Tensor linear_weight_int4_reference_impl(
   const size_t ndim = original_x_size.size();
   const int64_t out_features = weights_4x2.size(0);
   const at::Tensor x_flattened = x.reshape({-1, original_x_size[ndim - 1]});
-  const at::Tensor packed_weights =
-      at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles);
-  at::Tensor out = at::_weight_int4pack_mm(
-      x_flattened, packed_weights, groupsize, scales_and_zeros);
+  at::Tensor out = at::_weight_int4pack_mm_for_cpu(
+      x_flattened, weights_4x2, groupsize, scales_and_zeros);
   std::vector<int64_t> out_shape(
       original_x_size.begin(), original_x_size.end());
   out_shape.at(ndim - 1) = out_features;
   return out.reshape(out_shape);
 }
 
+at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) {
+  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
+  weights_shape[1] *= 2;
+
+  at::Tensor weights_unpacked =
+      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kInt));
+
+  const int64_t N = weights_unpacked.size(0);
+  const int64_t K = weights_unpacked.size(1);
+
+  for (int n = 0; n < N; n++) {
+    for (int k = 0; k < K; k += 2) {
+      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
+      const uint8_t second_val = packed_val & 0x0F;
+      const uint8_t first_val = (packed_val & 0xF0) >> 4;
+
+      weights_unpacked[n][k] = int(first_val);
+      weights_unpacked[n][k + 1] = int(second_val);
+    }
+  }
+
+  return weights_unpacked;
+}
+
 at::Tensor dequantize_and_linear(
     const at::Tensor& x,
     const at::Tensor& weights_4x2,
@@ -91,13 +113,18 @@ void test_reference_linear_int4(
   at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat));
   at::Tensor weights_4x2 =
       at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte));
+  at::Tensor weights_int = unpack_weights_4x2(weights_4x2);
 
   const int k_groups = K / group_size;
   at::Tensor scales_and_zeros =
       at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));
 
   at::Tensor out = linear_weight_int4_reference_impl(
-      x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
+      x,
+      at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size),
+      group_size,
+      scales_and_zeros,
+      inner_k_tiles);
 
   at::Tensor out_ref = dequantize_and_linear(
       x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
```
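As a standalone illustration of the packing scheme mentioned above, the following sketch (plain C++, no ATen; the variable names mirror `unpack_weights_4x2`) round-trips the nibble convention the test helper assumes: the high nibble of each packed byte holds the first 4-bit value and the low nibble holds the second.

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const uint8_t first_val = 0xA;   // stored in the high nibble
  const uint8_t second_val = 0x3;  // stored in the low nibble
  const uint8_t packed_val =
      static_cast<uint8_t>((first_val << 4) | second_val);

  // Unpack exactly as unpack_weights_4x2 does in the diff above.
  assert(((packed_val & 0xF0) >> 4) == first_val);
  assert((packed_val & 0x0F) == second_val);
  return 0;
}
```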
