Skip to content

[ET-VK] Allow int4 linear to execute without 8bit buffer support #10030

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/gen_vulkan_spv.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def buffer_gvec_type(dtype: str, n: int) -> str:

if dtype == "float":
return f"vec{n}"
if dtype == "uint":
return f"uvec{n}"
elif dtype == "half":
return f"f16vec{n}"
elif dtype == "int":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,52 @@

#define PRECISION ${PRECISION}

${define_required_extensions("uint8")}
${define_required_extensions("int8")}
$if not NO_INT8_BUFFERS:
${define_required_extensions("uint8")}
$if STORAGE == "buffer":
${define_required_extensions("int8")}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}
$if NO_INT8_BUFFERS:
${layout_declare_tensor(B, "r", "nchw_4x2", "uint", "buffer")}
$else:
${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}

layout(push_constant) uniform restrict Block {
ivec4 qmat2_sizes;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

uint8_t get_first(const uint8_t packed) {
return uint8_t((packed & 0xF0) >> 4);
$if NO_INT8_BUFFERS:
#define BUF_T uint
$else:
#define BUF_T uint8_t

$if STORAGE == "buffer":
#define UVEC4_T u8vec4
$else:
#define UVEC4_T uvec4

uint get_first(const BUF_T packed) {
return (packed & 0xF0) >> 4;
}

uint8_t get_second(const uint8_t packed) {
return uint8_t(packed & 0x0F);
uint get_second(const BUF_T packed) {
return packed & 0x0F;
}

uint8_t combine(const uint8_t first, const uint8_t second) {
return uint8_t(first << 4 | second);
uint combine(const uint first, const uint second) {
return (first << 4 | second);
}

$if NO_INT8_BUFFERS:
uint extract_comp(const uint packed4, const uint idx) {
return (packed4 >> (idx * 8)) & 0xFF;
}

/*
* This shader packs the weight tensor into a texture.
*
Expand Down Expand Up @@ -102,25 +122,32 @@ void main() {
int in_numcols = qmat2_sizes.y;
int in_num_int8_cols = qmat2_sizes.y >> 1;

uint8_t in_vals[8][2];
uint in_vals[8][2];
for (int r = 0; r < 8; ++r) {
if (in_row + r < in_numrows) {
uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col];
uint scalar_idx = (in_row + r) * in_num_int8_cols + in_int8_col;
$if NO_INT8_BUFFERS:
BUF_T in_val_packed_texel = nchw_4x2[scalar_idx >> 2];
const uint packed_idx = scalar_idx % 4;
uint in_val_packed = extract_comp(in_val_packed_texel, packed_idx);
$else:
BUF_T in_val_packed = nchw_4x2[scalar_idx];

in_vals[r][0] = get_first(in_val_packed);
in_vals[r][1] = get_second(in_val_packed);
} else {
in_vals[r][0] = uint8_t(0);
in_vals[r][1] = uint8_t(0);
in_vals[r][0] = uint(0);
in_vals[r][1] = uint(0);
}
}

u8vec4 out_tex_1 = u8vec4(
UVEC4_T out_tex_1 = UVEC4_T(
combine(in_vals[0][0], in_vals[4][0]),
combine(in_vals[1][0], in_vals[5][0]),
combine(in_vals[2][0], in_vals[6][0]),
combine(in_vals[3][0], in_vals[7][0]));

u8vec4 out_tex_2 = u8vec4(
UVEC4_T out_tex_2 = UVEC4_T(
combine(in_vals[0][1], in_vals[4][1]),
combine(in_vals[1][1], in_vals[5][1]),
combine(in_vals[2][1], in_vals[6][1]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
pack_int4_linear_weight_transposed_interleaved:
parameter_names_with_default_values:
STORAGE: texture2d
generate_variant_forall:
STORAGE:
- VALUE: texture2d
- VALUE: buffer
NO_INT8_BUFFERS: false
shader_variants:
- NAME: pack_int4_linear_weight_transposed_interleaved
- NAME: pack_int4_linear_weight_transposed_interleaved_texture2d
- NAME: pack_int4_linear_weight_transposed_interleaved_buffer
STORAGE: buffer
- NAME: pack_int4_linear_weight_transposed_interleaved_nobitw8buffer_texture2d
NO_INT8_BUFFERS: true
3 changes: 2 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

${define_required_extensions(DTYPE)}
${define_required_extensions("int8")}
$if WEIGHT_STORAGE == "buffer":
${define_required_extensions("uint8")}

layout(std430) buffer;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ void check_q_4w_linear_args(
const ValueRef group_size,
const ValueRef scales_and_zeros,
const ValueRef out) {
VK_CHECK_COND(graph.int16_shader_types_enabled());
VK_CHECK_COND(graph.int8_buffers_enabled());

VK_CHECK_COND(graph.val_is_tensor(mat1));
VK_CHECK_COND(graph.val_is_tref(mat2_data));
VK_CHECK_COND(graph.val_is_tref(scales_and_zeros));
Expand Down Expand Up @@ -97,7 +94,10 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved(
global_wg_size = graph.logical_limits_of(qmat2);
global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(2));

std::string kernel_name = "pack_int4_linear_weight_transposed_interleaved";
std::string kernel_name =
graph.context()->adapter_ptr()->has_full_int8_buffers_support()
? "pack_int4_linear_weight_transposed_interleaved"
: "pack_int4_linear_weight_transposed_interleaved_nobitw8buffer";
add_storage_type_suffix(kernel_name, storage_type);

graph.prepack_nodes().emplace_back(new PrepackNode(
Expand Down
5 changes: 0 additions & 5 deletions backends/vulkan/test/op_tests/linear_weight_int4_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,6 @@ TEST(VulkanInt4LinearTest, test_reference_impl) {
}

TEST(VulkanInt4LinearTest, test_vulkan_impl) {
if (!vkcompute::api::context()
->adapter_ptr()
->has_full_int8_buffers_support()) {
GTEST_SKIP();
}
test_vulkan_linear_int4(
/*B = */ 1,
/*M = */ 4,
Expand Down
Loading