[ET-VK][ez] Fix linear weight int4 test due to change in ATen API #7739

Merged · 2 commits · Jan 17, 2025
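For context, here is a minimal sketch (not part of the PR) of the ATen API migration this test adapts to: the int4 weight-packed matmul path used by the test moved to dedicated `_for_cpu` operators, which consume weights packed via `at::_convert_weight_to_int4pack_for_cpu`. The calls mirror the diff below; the shapes and group size are illustrative assumptions, and `unpack_weights_4x2` is the helper added in this PR.

#include <ATen/ATen.h>

// Sketch only: contrasts the old and new call patterns from the diff below.
// Shapes are illustrative assumptions (M = 4, K = 256, N = 64, group_size = 128).
at::Tensor int4_mm_cpu_sketch() {
  at::Tensor x = at::rand({4, 256}, at::device(at::kCPU).dtype(at::kFloat));
  at::Tensor weights_4x2 =
      at::randint(0, 256, {64, 128}, at::device(at::kCPU).dtype(at::kByte));
  at::Tensor scales_and_zeros =
      at::rand({2, 64, 2}, at::device(at::kCPU).dtype(at::kFloat));

  // Old flow (removed in this PR):
  //   at::Tensor packed = at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles);
  //   at::Tensor out = at::_weight_int4pack_mm(x, packed, 128, scales_and_zeros);

  // New CPU flow (as used in this PR):
  at::Tensor weights_int = unpack_weights_4x2(weights_4x2);  // helper from the diff
  at::Tensor packed =
      at::_convert_weight_to_int4pack_for_cpu(weights_int, /*group_size=*/128);
  return at::_weight_int4pack_mm_for_cpu(x, packed, /*group_size=*/128, scales_and_zeros);
}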
backends/vulkan/test/op_tests/linear_weight_int4_test.cpp (37 changes: 32 additions & 5 deletions)
@@ -30,16 +30,38 @@ at::Tensor linear_weight_int4_reference_impl(
   const size_t ndim = original_x_size.size();
   const int64_t out_features = weights_4x2.size(0);
   const at::Tensor x_flattened = x.reshape({-1, original_x_size[ndim - 1]});
-  const at::Tensor packed_weights =
-      at::_convert_weight_to_int4pack(weights_4x2, inner_k_tiles);
-  at::Tensor out = at::_weight_int4pack_mm(
-      x_flattened, packed_weights, groupsize, scales_and_zeros);
+  at::Tensor out = at::_weight_int4pack_mm_for_cpu(
+      x_flattened, weights_4x2, groupsize, scales_and_zeros);
   std::vector<int64_t> out_shape(
       original_x_size.begin(), original_x_size.end());
   out_shape.at(ndim - 1) = out_features;
   return out.reshape(out_shape);
 }

+at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) {
+  std::vector<int64_t> weights_shape(weights_4x2.sizes().vec());
+  weights_shape[1] *= 2;
+
+  at::Tensor weights_unpacked =
+      at::empty(weights_shape, at::device(at::kCPU).dtype(at::kInt));
+
+  const int64_t N = weights_unpacked.size(0);
+  const int64_t K = weights_unpacked.size(1);
+
+  for (int n = 0; n < N; n++) {
+    for (int k = 0; k < K; k += 2) {
+      const uint8_t packed_val = weights_4x2[n][k / 2].item().to<uint8_t>();
+      const uint8_t second_val = packed_val & 0x0F;
+      const uint8_t first_val = (packed_val & 0xF0) >> 4;
+
+      weights_unpacked[n][k] = int(first_val);
+      weights_unpacked[n][k + 1] = int(second_val);
+    }
+  }
+
+  return weights_unpacked;
+}
+
 at::Tensor dequantize_and_linear(
     const at::Tensor& x,
     const at::Tensor& weights_4x2,
@@ -91,13 +113,18 @@ void test_reference_linear_int4(
   at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat));
   at::Tensor weights_4x2 =
       at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte));
+  at::Tensor weights_int = unpack_weights_4x2(weights_4x2);

   const int k_groups = K / group_size;
   at::Tensor scales_and_zeros =
       at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat));

   at::Tensor out = linear_weight_int4_reference_impl(
-      x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
+      x,
+      at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size),
+      group_size,
+      scales_and_zeros,
+      inner_k_tiles);

   at::Tensor out_ref = dequantize_and_linear(
       x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles);
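As a sanity check on the packing convention `unpack_weights_4x2` assumes (the high nibble of each byte holds the first of the two 4-bit values), here is a small standalone sketch of the same bit manipulation:

#include <cstdint>
#include <cstdio>

int main() {
  // 0xAB packs two 4-bit values: high nibble 0xA (10), low nibble 0xB (11).
  const uint8_t packed_val = 0xAB;
  const uint8_t first_val = (packed_val & 0xF0) >> 4;  // high nibble -> 10
  const uint8_t second_val = packed_val & 0x0F;        // low nibble  -> 11
  std::printf("%u %u\n",
              static_cast<unsigned>(first_val),
              static_cast<unsigned>(second_val));      // prints "10 11"
  return 0;
}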