Skip to content

vulkan: further optimize q5_k mul_mat_vec #10479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
Original file line number | Diff line number | Diff line change
@@ -34,9 +34,6 @@ void main() {
const uint q_offset = 32*v_im + l0;
const uint y_offset = 64*v_im + l0;

const uint8_t hm1 = uint8_t(1 << (2*v_im));
const uint8_t hm2 = uint8_t(hm1 << 4);

FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp

[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
@@ -71,6 +68,18 @@ void main() {
uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;

uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));

uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0;
uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;

qs0_16_u32_lo4 += qs0_16_lo4_offset16;
qs0_16_u32_hi4 += qs0_16_hi4_offset16;
qs64_80_u32_lo4 += qs64_80_lo4_offset16;
qs64_80_u32_hi4 += qs64_80_hi4_offset16;

uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
@@ -102,31 +111,26 @@ void main() {
B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];

uint32_t qh0 = data_a_packed16[ib0 + i].qh[l0 / 2];
uint32_t qh1 = qh0 >> 8;
uint32_t qh16 = data_a_packed16[ib0 + i].qh[l0 / 2 + 8];
uint32_t qh17 = qh16 >> 8;

const FLOAT_TYPE sx =
fma(FLOAT_TYPE(by10.x), (q4_0 + (((qh0 & hm1) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by10.y), (q4_1 + (((qh1 & hm1) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by116.x), (q4_2 + (((qh16 & hm1) != 0) ? 16 : 0)),
FLOAT_TYPE(by116.y) * (q4_3 + (((qh17 & hm1) != 0) ? 16 : 0)))));
fma(FLOAT_TYPE(by10.x), q4_0,
fma(FLOAT_TYPE(by10.y), q4_1,
fma(FLOAT_TYPE(by116.x), q4_2,
FLOAT_TYPE(by116.y) * q4_3)));
const FLOAT_TYPE sy =
fma(FLOAT_TYPE(by132.x), (q4_4 + (((qh0 & (hm1 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by132.y), (q4_5 + (((qh1 & (hm1 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by148.x), (q4_6 + (((qh16 & (hm1 << 1)) != 0) ? 16 : 0)),
FLOAT_TYPE(by148.y) * (q4_7 + (((qh17 & (hm1 << 1)) != 0) ? 16 : 0)))));
fma(FLOAT_TYPE(by132.x), q4_4,
fma(FLOAT_TYPE(by132.y), q4_5,
fma(FLOAT_TYPE(by148.x), q4_6,
FLOAT_TYPE(by148.y) * q4_7)));
const FLOAT_TYPE sz =
fma(FLOAT_TYPE(by20.x), (q4_8 + (((qh0 & hm2) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by20.y), (q4_9 + (((qh1 & hm2) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by216.x), (q4_10 + (((qh16 & hm2) != 0) ? 16 : 0)),
FLOAT_TYPE(by216.y) * (q4_11 + (((qh17 & hm2) != 0) ? 16 : 0)))));
fma(FLOAT_TYPE(by20.x), q4_8,
fma(FLOAT_TYPE(by20.y), q4_9,
fma(FLOAT_TYPE(by216.x), q4_10,
FLOAT_TYPE(by216.y) * q4_11)));
const FLOAT_TYPE sw =
fma(FLOAT_TYPE(by232.x), (q4_12 + (((qh0 & (hm2 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by232.y), (q4_13 + (((qh1 & (hm2 << 1)) != 0) ? 16 : 0)),
fma(FLOAT_TYPE(by248.x), (q4_14 + (((qh16 & (hm2 << 1)) != 0) ? 16 : 0)),
FLOAT_TYPE(by248.y) * (q4_15 + (((qh17 & (hm2 << 1)) != 0) ? 16 : 0)))));
fma(FLOAT_TYPE(by232.x), q4_12,
fma(FLOAT_TYPE(by232.y), q4_13,
fma(FLOAT_TYPE(by248.x), q4_14,
FLOAT_TYPE(by248.y) * q4_15)));
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
Expand Down
Loading