Skip to content

Commit 28c458a

Browse files
vulkan: increase LOAD_VEC_A to 8 (IQ1/IQ2) or 4 (IQ3)
1 parent da0d698 commit 28c458a

File tree

2 files changed

+83
-76
lines changed

2 files changed

+83
-76
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 81 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -483,34 +483,26 @@ void main() {
483483
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
484484
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
485485

486-
const uint ib = idx / 128; // 2 values per idx
487-
const uint ib32 = (idx % 128) / 16; // 0..7
488-
const uint ib8 = (idx % 128) / 4;
489-
const int i8 = 2 * int(idx % 4);
486+
const uint ib = idx / 32; // 8 values per idx
487+
const uint ib32 = (idx % 32) / 4; // 0..7
488+
const uint ib8 = idx % 32;
490489

491490
const float d = float(data_a[ib].d);
492491
const uint qh = data_a[ib].qh[ib32];
493492
const uint qs = data_a[ib].qs[ib8];
494493
const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
495494
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
496495
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
497-
498-
const ivec2 gvec = ivec2(
499-
bitfieldExtract(grid, 2 * (i8), 2),
500-
bitfieldExtract(grid, 2 * (i8 + 1), 2)
501-
);
502-
const vec2 v = dl * (vec2(gvec) + delta);
503-
504-
buf_a[buf_idx ] = BUF_TYPE(v.x);
505-
buf_a[buf_idx + 1] = BUF_TYPE(v.y);
496+
[[unroll]] for (int k = 0; k < 8; ++k) {
497+
buf_a[buf_idx + k] = BUF_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
498+
}
506499
#elif defined(DATA_A_IQ1_M)
507500
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
508501
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
509502

510-
const uint ib = idx / 128; // 2 values per idx
511-
const uint ib8 = (idx % 128) / 4;
503+
const uint ib = idx / 32; // 8 values per idx
504+
const uint ib8 = idx % 32;
512505
const uint ib16 = ib8 / 2;
513-
const int i8 = 2 * int(idx % 4);
514506

515507
const uint16_t[4] scales = data_a[ib].scales;
516508
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@@ -521,21 +513,16 @@ void main() {
521513
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
522514
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
523515
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
524-
const ivec2 gvec = ivec2(
525-
bitfieldExtract(grid, 2 * (i8), 2),
526-
bitfieldExtract(grid, 2 * (i8 + 1), 2)
527-
);
528-
const vec2 v = dl * (vec2(gvec) + delta);
529-
530-
buf_a[buf_idx ] = BUF_TYPE(v.x);
531-
buf_a[buf_idx + 1] = BUF_TYPE(v.y);
516+
[[unroll]] for (int k = 0; k < 8; ++k) {
517+
buf_a[buf_idx + k] = BUF_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
518+
}
532519
#elif defined(DATA_A_IQ2_XXS)
533520
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
534521
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
535522

536-
const uint ib = idx / 128; // 2 values per idx
537-
const uint ib32 = (idx % 128) / 16; // 0..7
538-
const uint ib8 = (idx / 4) % 4;
523+
const uint ib = idx / 32; // 8 values per idx
524+
const uint ib32 = (idx % 32) / 4; // 0..7
525+
const uint ib8 = idx % 4;
539526

540527
const float d = float(data_a[ib].d);
541528
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@@ -545,63 +532,81 @@ void main() {
545532
data_a[ib].qs[8*ib32 + 6],
546533
data_a[ib].qs[8*ib32 + 7]
547534
));
548-
const float db = d * 0.25 * (0.5 + (signs >> 28));
535+
const BUF_TYPE db = BUF_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
549536
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
550-
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
551-
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
552-
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
553-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
554-
555-
buf_a[buf_idx ] = BUF_TYPE(v.x);
556-
buf_a[buf_idx + 1] = BUF_TYPE(v.y);
537+
const uint sign = sign7 | (bitCount(sign7) << 7);
538+
const uvec2 grid = iq2xxs_grid[qs];
539+
const vec4 grid0 = vec4(unpack8(grid.x));
540+
const vec4 grid1 = vec4(unpack8(grid.y));
541+
542+
buf_a[buf_idx ] = db * BUF_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
543+
buf_a[buf_idx + 1] = db * BUF_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
544+
buf_a[buf_idx + 2] = db * BUF_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
545+
buf_a[buf_idx + 3] = db * BUF_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
546+
buf_a[buf_idx + 4] = db * BUF_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
547+
buf_a[buf_idx + 5] = db * BUF_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
548+
buf_a[buf_idx + 6] = db * BUF_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
549+
buf_a[buf_idx + 7] = db * BUF_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
557550
#elif defined(DATA_A_IQ2_XS)
558551
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
559552
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
560553

561-
const uint ib = idx / 128; // 2 values per idx
562-
const uint ib32 = (idx % 128) / 16; // 0..7
563-
const uint ib8 = (idx / 4) % 4; // 0..3
554+
const uint ib = idx / 32; // 8 values per idx
555+
const uint ib32 = (idx % 32) / 4; // 0..7
556+
const uint ib8 = idx % 4; // 0..3
564557

565558
const float d = float(data_a[ib].d);
566559
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
567-
const float db = d * 0.25 * (0.5 + scale);
560+
const BUF_TYPE db = BUF_TYPE(d * 0.25 * (0.5 + scale));
568561
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
569562
const uint sign7 = qs >> 9;
570-
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
571-
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
572-
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
573-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
574-
575-
buf_a[buf_idx ] = BUF_TYPE(v.x);
576-
buf_a[buf_idx + 1] = BUF_TYPE(v.y);
563+
const uint sign = sign7 | (bitCount(sign7) << 7);
564+
const uvec2 grid = iq2xs_grid[qs & 511];
565+
const vec4 grid0 = vec4(unpack8(grid.x));
566+
const vec4 grid1 = vec4(unpack8(grid.y));
567+
568+
buf_a[buf_idx ] = db * BUF_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
569+
buf_a[buf_idx + 1] = db * BUF_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
570+
buf_a[buf_idx + 2] = db * BUF_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
571+
buf_a[buf_idx + 3] = db * BUF_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
572+
buf_a[buf_idx + 4] = db * BUF_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
573+
buf_a[buf_idx + 5] = db * BUF_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
574+
buf_a[buf_idx + 6] = db * BUF_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
575+
buf_a[buf_idx + 7] = db * BUF_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
577576
#elif defined(DATA_A_IQ2_S)
578577
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
579578
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
580579

581-
const uint ib = idx / 128; // 2 values per idx
582-
const uint ib8 = (idx % 128) / 4; // 0..31
583-
const uint ib32 = ib8 / 4; // 0..7
580+
const uint ib = idx / 32; // 8 values per idx
581+
const uint ib8 = idx % 32; // 0..31
582+
const uint ib32 = ib8 / 4; // 0..7
584583

585584
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
586585
const uint qs = data_a[ib].qs[ib8];
587586
const uint qh = data_a[ib].qh[ib32];
588587
const uint qhshift = 2 * (ib8 % 4);
589-
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
588+
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
590589

591590
const float d = float(data_a[ib].d);
592-
const float db = d * 0.25 * (0.5 + scale);
593-
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
594-
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
595-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
596-
597-
buf_a[buf_idx ] = BUF_TYPE(v.x);
598-
buf_a[buf_idx + 1] = BUF_TYPE(v.y);
591+
const BUF_TYPE db = BUF_TYPE(d * 0.25 * (0.5 + scale));
592+
const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
593+
const vec4 grid0 = vec4(unpack8(grid.x));
594+
const vec4 grid1 = vec4(unpack8(grid.y));
595+
596+
buf_a[buf_idx ] = db * BUF_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
597+
buf_a[buf_idx + 1] = db * BUF_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
598+
buf_a[buf_idx + 2] = db * BUF_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
599+
buf_a[buf_idx + 3] = db * BUF_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
600+
buf_a[buf_idx + 4] = db * BUF_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
601+
buf_a[buf_idx + 5] = db * BUF_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
602+
buf_a[buf_idx + 6] = db * BUF_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
603+
buf_a[buf_idx + 7] = db * BUF_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
599604
#elif defined(DATA_A_IQ3_XXS)
600605
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
601606
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
602607

603-
const uint ib = idx / 128; // 2 values per idx
604-
const uint iqs = (idx % 128) / 2; // 0..63
608+
const uint ib = idx / 64; // 4 values per idx
609+
const uint iqs = idx % 64; // 0..63
605610
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
606611

607612
const float d = float(data_a[ib].d);
@@ -614,33 +619,35 @@ void main() {
614619
));
615620
const float db = d * 0.5 * (0.5 + (signs >> 28));
616621
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
617-
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
618-
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
619-
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
620-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
621-
622-
buf_a[buf_idx ] = BUF_TYPE(v.x);
623-
buf_a[buf_idx + 1] = BUF_TYPE(v.y);
622+
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
623+
const uint grid = iq3xxs_grid[qs];
624+
const vec4 v = db * vec4(unpack8(grid));
625+
626+
buf_a[buf_idx ] = BUF_TYPE((sign & 1) != 0 ? -v.x : v.x);
627+
buf_a[buf_idx + 1] = BUF_TYPE((sign & 2) != 0 ? -v.y : v.y);
628+
buf_a[buf_idx + 2] = BUF_TYPE((sign & 4) != 0 ? -v.z : v.z);
629+
buf_a[buf_idx + 3] = BUF_TYPE((sign & 8) != 0 ? -v.w : v.w);
624630
#elif defined(DATA_A_IQ3_S)
625631
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
626632
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
627633

628-
const uint ib = idx / 128; // 2 values per idx
629-
const uint iqs = (idx % 128) / 2; // 0..63
634+
const uint ib = idx / 64; // 4 values per idx
635+
const uint iqs = idx % 64; // 0..63
630636
const uint iqh = iqs / 8;
631637

632638
const float d = float(data_a[ib].d);
633639
const uint qs = data_a[ib].qs[iqs];
634640
const uint qh = data_a[ib].qh[iqh];
635-
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
641+
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
636642
const uint scale = data_a[ib].scales[iqs / 16];
637-
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
638643
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
639-
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
640-
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
644+
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
645+
const vec4 v = db * vec4(unpack8(grid));
641646

642-
buf_a[buf_idx ] = BUF_TYPE(v.x);
643-
buf_a[buf_idx + 1] = BUF_TYPE(v.y);
647+
buf_a[buf_idx ] = BUF_TYPE((sign & 1) != 0 ? -v.x : v.x);
648+
buf_a[buf_idx + 1] = BUF_TYPE((sign & 2) != 0 ? -v.y : v.y);
649+
buf_a[buf_idx + 2] = BUF_TYPE((sign & 4) != 0 ? -v.z : v.z);
650+
buf_a[buf_idx + 3] = BUF_TYPE((sign & 8) != 0 ? -v.w : v.w);
644651
#elif defined(DATA_A_IQ4_XS)
645652
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
646653
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
326326

327327
for (const auto& tname : type_names) {
328328
std::string load_vec_quant = "2";
329-
if ((tname == "q4_0") || (tname == "q4_1"))
329+
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
330330
load_vec_quant = "8";
331-
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
331+
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
332332
load_vec_quant = "4";
333333

334334
std::string data_a_key = "DATA_A_" + to_uppercase(tname);

0 commit comments

Comments
 (0)