Skip to content

Commit 30c42ef

Browse files
authored
vulkan: workaround for AMD Windows driver 16 bit unpack8 bug (#12472)
1 parent af04481 commit 30c42ef

File tree

4 files changed: 12 additions and 12 deletions.

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -82,8 +82,8 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
8282
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
8383
}
8484
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
85-
const i8vec2 v0 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2]);
86-
const i8vec2 v1 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2 + 1]);
85+
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147
86+
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy;
8787
return vec4(v0.x, v0.y, v1.x, v1.y);
8888
}
8989
#endif

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
1919
const float db = d * (0.5 + scale) * 0.25;
2020

2121
const uint qh = data_a[ibi].qh[ib32];
22-
const u8vec2 qs16 = unpack8(data_a_packed16[ibi].qs[itid]);
23-
const u8vec2 sign16 = unpack8(data_a_packed16[ibi].qs[QUANT_K / 16 + itid]);
22+
const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147
23+
const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy;
2424
[[unroll]] for (uint l = 0; l < 2; ++l) {
2525
const uint8_t sign = sign16[l];
2626
const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300);

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
2121
sum[j] = 0.0;
2222
}
2323
[[unroll]] for (uint l = 0; l < 4; ++l) {
24-
const u8vec2 qs = unpack8(data_a_packed16[ibi].qs[4 * ib32 + l]);
24+
const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147
2525
const uint sign = data_a[ibi].signs[4 * ib32 + l];
2626
const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)]));
2727
const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)]));

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -336,8 +336,8 @@ void main() {
336336
const uint iqs = idx & 0x07;
337337

338338
const float d = float(data_a_packed16[ib].d);
339-
const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]);
340-
const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]);
339+
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
340+
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
341341
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
342342

343343
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
@@ -544,7 +544,7 @@ void main() {
544544
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
545545
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
546546
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
547-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
547+
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
548548

549549
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
550550
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@@ -564,7 +564,7 @@ void main() {
564564
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
565565
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
566566
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
567-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
567+
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
568568

569569
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
570570
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@@ -586,7 +586,7 @@ void main() {
586586
const float db = d * 0.25 * (0.5 + scale);
587587
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
588588
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
589-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
589+
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
590590

591591
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
592592
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@@ -611,7 +611,7 @@ void main() {
611611
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
612612
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
613613
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
614-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
614+
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
615615

616616
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
617617
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@@ -631,7 +631,7 @@ void main() {
631631
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
632632
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
633633
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
634-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
634+
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
635635

636636
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
637637
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);

0 commit comments

Comments
 (0)