@@ -483,34 +483,26 @@ void main() {
483
483
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
484
484
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
485
485
486
- const uint ib = idx / 128; // 2 values per idx
487
- const uint ib32 = (idx % 128) / 16; // 0..7
488
- const uint ib8 = (idx % 128) / 4;
489
- const int i8 = 2 * int(idx % 4);
486
+ const uint ib = idx / 32; // 8 values per idx
487
+ const uint ib32 = (idx % 32) / 4; // 0..7
488
+ const uint ib8 = idx % 32;
490
489
491
490
const float d = float(data_a[ib].d);
492
491
const uint qh = data_a[ib].qh[ib32];
493
492
const uint qs = data_a[ib].qs[ib8];
494
493
const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
495
494
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
496
495
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
497
-
498
- const ivec2 gvec = ivec2(
499
- bitfieldExtract(grid, 2 * (i8), 2),
500
- bitfieldExtract(grid, 2 * (i8 + 1), 2)
501
- );
502
- const vec2 v = dl * (vec2(gvec) + delta);
503
-
504
- buf_a[buf_idx ] = BUF_TYPE(v.x);
505
- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
496
+ [[unroll]] for (int k = 0; k < 8; ++k) {
497
+ buf_a[buf_idx + k] = BUF_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
498
+ }
506
499
#elif defined(DATA_A_IQ1_M)
507
500
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
508
501
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
509
502
510
- const uint ib = idx / 128 ; // 2 values per idx
511
- const uint ib8 = ( idx % 128) / 4 ;
503
+ const uint ib = idx / 32 ; // 8 values per idx
504
+ const uint ib8 = idx % 32 ;
512
505
const uint ib16 = ib8 / 2;
513
- const int i8 = 2 * int(idx % 4);
514
506
515
507
const uint16_t[4] scales = data_a[ib].scales;
516
508
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@@ -521,21 +513,16 @@ void main() {
521
513
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
522
514
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
523
515
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
524
- const ivec2 gvec = ivec2(
525
- bitfieldExtract(grid, 2 * (i8), 2),
526
- bitfieldExtract(grid, 2 * (i8 + 1), 2)
527
- );
528
- const vec2 v = dl * (vec2(gvec) + delta);
529
-
530
- buf_a[buf_idx ] = BUF_TYPE(v.x);
531
- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
516
+ [[unroll]] for (int k = 0; k < 8; ++k) {
517
+ buf_a[buf_idx + k] = BUF_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
518
+ }
532
519
#elif defined(DATA_A_IQ2_XXS)
533
520
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
534
521
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
535
522
536
- const uint ib = idx / 128 ; // 2 values per idx
537
- const uint ib32 = (idx % 128 ) / 16 ; // 0..7
538
- const uint ib8 = ( idx / 4) % 4;
523
+ const uint ib = idx / 32 ; // 8 values per idx
524
+ const uint ib32 = (idx % 32 ) / 4 ; // 0..7
525
+ const uint ib8 = idx % 4;
539
526
540
527
const float d = float(data_a[ib].d);
541
528
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@@ -545,63 +532,81 @@ void main() {
545
532
data_a[ib].qs[8*ib32 + 6],
546
533
data_a[ib].qs[8*ib32 + 7]
547
534
));
548
- const float db = d * 0.25 * (0.5 + (signs >> 28));
535
+ const BUF_TYPE db = BUF_TYPE( d * 0.25 * (0.5 + (signs >> 28) ));
549
536
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
550
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
551
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
552
- const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
553
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
554
-
555
- buf_a[buf_idx ] = BUF_TYPE(v.x);
556
- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
537
+ const uint sign = sign7 | (bitCount(sign7) << 7);
538
+ const uvec2 grid = iq2xxs_grid[qs];
539
+ const vec4 grid0 = vec4(unpack8(grid.x));
540
+ const vec4 grid1 = vec4(unpack8(grid.y));
541
+
542
+ buf_a[buf_idx ] = db * BUF_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
543
+ buf_a[buf_idx + 1] = db * BUF_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
544
+ buf_a[buf_idx + 2] = db * BUF_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
545
+ buf_a[buf_idx + 3] = db * BUF_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
546
+ buf_a[buf_idx + 4] = db * BUF_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
547
+ buf_a[buf_idx + 5] = db * BUF_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
548
+ buf_a[buf_idx + 6] = db * BUF_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
549
+ buf_a[buf_idx + 7] = db * BUF_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
557
550
#elif defined(DATA_A_IQ2_XS)
558
551
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
559
552
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
560
553
561
- const uint ib = idx / 128 ; // 2 values per idx
562
- const uint ib32 = (idx % 128 ) / 16; // 0..7
563
- const uint ib8 = ( idx / 4) % 4; // 0..3
554
+ const uint ib = idx / 32 ; // 8 values per idx
555
+ const uint ib32 = (idx % 32 ) / 4; // 0..7
556
+ const uint ib8 = idx % 4; // 0..3
564
557
565
558
const float d = float(data_a[ib].d);
566
559
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
567
- const float db = d * 0.25 * (0.5 + scale);
560
+ const BUF_TYPE db = BUF_TYPE( d * 0.25 * (0.5 + scale) );
568
561
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
569
562
const uint sign7 = qs >> 9;
570
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
571
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
572
- const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
573
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
574
-
575
- buf_a[buf_idx ] = BUF_TYPE(v.x);
576
- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
563
+ const uint sign = sign7 | (bitCount(sign7) << 7);
564
+ const uvec2 grid = iq2xs_grid[qs & 511];
565
+ const vec4 grid0 = vec4(unpack8(grid.x));
566
+ const vec4 grid1 = vec4(unpack8(grid.y));
567
+
568
+ buf_a[buf_idx ] = db * BUF_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
569
+ buf_a[buf_idx + 1] = db * BUF_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
570
+ buf_a[buf_idx + 2] = db * BUF_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
571
+ buf_a[buf_idx + 3] = db * BUF_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
572
+ buf_a[buf_idx + 4] = db * BUF_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
573
+ buf_a[buf_idx + 5] = db * BUF_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
574
+ buf_a[buf_idx + 6] = db * BUF_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
575
+ buf_a[buf_idx + 7] = db * BUF_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
577
576
#elif defined(DATA_A_IQ2_S)
578
577
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
579
578
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
580
579
581
- const uint ib = idx / 128 ; // 2 values per idx
582
- const uint ib8 = ( idx % 128) / 4 ; // 0..31
583
- const uint ib32 = ib8 / 4; // 0..7
580
+ const uint ib = idx / 32 ; // 8 values per idx
581
+ const uint ib8 = idx % 32 ; // 0..31
582
+ const uint ib32 = ib8 / 4; // 0..7
584
583
585
584
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
586
585
const uint qs = data_a[ib].qs[ib8];
587
586
const uint qh = data_a[ib].qh[ib32];
588
587
const uint qhshift = 2 * (ib8 % 4);
589
- const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)) ;
588
+ const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
590
589
591
590
const float d = float(data_a[ib].d);
592
- const float db = d * 0.25 * (0.5 + scale);
593
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
594
- const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
595
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
596
-
597
- buf_a[buf_idx ] = BUF_TYPE(v.x);
598
- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
591
+ const BUF_TYPE db = BUF_TYPE(d * 0.25 * (0.5 + scale));
592
+ const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
593
+ const vec4 grid0 = vec4(unpack8(grid.x));
594
+ const vec4 grid1 = vec4(unpack8(grid.y));
595
+
596
+ buf_a[buf_idx ] = db * BUF_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
597
+ buf_a[buf_idx + 1] = db * BUF_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
598
+ buf_a[buf_idx + 2] = db * BUF_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
599
+ buf_a[buf_idx + 3] = db * BUF_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
600
+ buf_a[buf_idx + 4] = db * BUF_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
601
+ buf_a[buf_idx + 5] = db * BUF_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
602
+ buf_a[buf_idx + 6] = db * BUF_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
603
+ buf_a[buf_idx + 7] = db * BUF_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
599
604
#elif defined(DATA_A_IQ3_XXS)
600
605
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
601
606
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
602
607
603
- const uint ib = idx / 128 ; // 2 values per idx
604
- const uint iqs = ( idx % 128) / 2 ; // 0..63
608
+ const uint ib = idx / 64 ; // 4 values per idx
609
+ const uint iqs = idx % 64 ; // 0..63
605
610
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
606
611
607
612
const float d = float(data_a[ib].d);
@@ -614,33 +619,35 @@ void main() {
614
619
));
615
620
const float db = d * 0.5 * (0.5 + (signs >> 28));
616
621
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
617
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
618
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
619
- const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
620
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
621
-
622
- buf_a[buf_idx ] = BUF_TYPE(v.x);
623
- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
622
+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
623
+ const uint grid = iq3xxs_grid[qs];
624
+ const vec4 v = db * vec4(unpack8(grid));
625
+
626
+ buf_a[buf_idx ] = BUF_TYPE((sign & 1) != 0 ? -v.x : v.x);
627
+ buf_a[buf_idx + 1] = BUF_TYPE((sign & 2) != 0 ? -v.y : v.y);
628
+ buf_a[buf_idx + 2] = BUF_TYPE((sign & 4) != 0 ? -v.z : v.z);
629
+ buf_a[buf_idx + 3] = BUF_TYPE((sign & 8) != 0 ? -v.w : v.w);
624
630
#elif defined(DATA_A_IQ3_S)
625
631
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
626
632
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
627
633
628
- const uint ib = idx / 128 ; // 2 values per idx
629
- const uint iqs = ( idx % 128) / 2 ; // 0..63
634
+ const uint ib = idx / 64 ; // 4 values per idx
635
+ const uint iqs = idx % 64 ; // 0..63
630
636
const uint iqh = iqs / 8;
631
637
632
638
const float d = float(data_a[ib].d);
633
639
const uint qs = data_a[ib].qs[iqs];
634
640
const uint qh = data_a[ib].qh[iqh];
635
- const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4 )));
641
+ const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2 )));
636
642
const uint scale = data_a[ib].scales[iqs / 16];
637
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
638
643
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
639
- const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)) ;
640
- const vec2 v = db * vec2(sign01) * vec2( unpack8(grid).xy );
644
+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
645
+ const vec4 v = db * vec4( unpack8(grid));
641
646
642
- buf_a[buf_idx ] = BUF_TYPE(v.x);
643
- buf_a[buf_idx + 1] = BUF_TYPE(v.y);
647
+ buf_a[buf_idx ] = BUF_TYPE((sign & 1) != 0 ? -v.x : v.x);
648
+ buf_a[buf_idx + 1] = BUF_TYPE((sign & 2) != 0 ? -v.y : v.y);
649
+ buf_a[buf_idx + 2] = BUF_TYPE((sign & 4) != 0 ? -v.z : v.z);
650
+ buf_a[buf_idx + 3] = BUF_TYPE((sign & 8) != 0 ? -v.w : v.w);
644
651
#elif defined(DATA_A_IQ4_XS)
645
652
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
646
653
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
0 commit comments