@@ -606,17 +606,29 @@ class tinyBLAS_Q0_AVX {
         case 0x44:
             mc = 4;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<4>(m0, m, n0, n);
+#else
             gemm<4, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x43:
             mc = 4;
             nc = 3;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<3>(m0, m, n0, n);
+#else
             gemm<4, 3>(m0, m, n0, n);
+#endif
             break;
         case 0x34:
             mc = 3;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<3>(m0, m, n0, n);
+#else
             gemm<3, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x33:
             mc = 3;
@@ -626,26 +638,42 @@ class tinyBLAS_Q0_AVX {
         case 0x42:
             mc = 4;
             nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<2>(m0, m, n0, n);
+#else
             gemm<4, 2>(m0, m, n0, n);
+#endif
             break;
         case 0x24:
             mc = 2;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<2>(m0, m, n0, n);
+#else
             gemm<2, 4>(m0, m, n0, n);
+#endif
             break;
 #else
         case 0x44:
         case 0x43:
         case 0x42:
             mc = 4;
             nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<2>(m0, m, n0, n);
+#else
             gemm<4, 2>(m0, m, n0, n);
+#endif
             break;
         case 0x34:
         case 0x24:
             mc = 2;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<2>(m0, m, n0, n);
+#else
             gemm<2, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x33:
 #endif
@@ -662,7 +690,11 @@ class tinyBLAS_Q0_AVX {
         case 0x41:
             mc = 4;
             nc = 1;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<1>(m0, m, n0, n);
+#else
             gemm<4, 1>(m0, m, n0, n);
+#endif
             break;
         case 0x22:
             mc = 2;
@@ -672,7 +704,11 @@ class tinyBLAS_Q0_AVX {
         case 0x14:
             mc = 1;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<1>(m0, m, n0, n);
+#else
             gemm<1, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x31:
             mc = 3;
@@ -708,6 +744,119 @@ class tinyBLAS_Q0_AVX {
         mnpack(m0, m, np, n);
     }
 
+#if defined(__AVX2__) && defined(__F16C__)
+    // Templated function for gemm of dimensions 4xN
+    template <int RN>
+    NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / 4;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * 4;
+            int64_t jj = n0 + job % xtiles * RN;
+            __m256 Cv[RN][4] = {};
+            for (int64_t l = 0; l < k; ++l) {
+                uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
+                // Convert the delta values of the four A blocks to float
+                __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
+                __m256i avec0 = load(A + lda * (ii + 0) + l);
+                __m256i avec1 = load(A + lda * (ii + 1) + l);
+                __m256i avec2 = load(A + lda * (ii + 2) + l);
+                __m256i avec3 = load(A + lda * (ii + 3) + l);
+                for (int64_t j = 0; j < RN; ++j) {
+                    __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
+                    // Multiply the delta values of the four blocks pairwise and replicate the products across the 256-bit lane
+                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
+                    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
+                    // Compute the dot products and scale them by the matching delta products
+                    Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+                                    updot(_mm256_sign_epi8(avec0, avec0),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
+                                    Cv[j][0]);
+                    Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+                                    updot(_mm256_sign_epi8(avec1, avec1),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
+                                    Cv[j][1]);
+                    Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+                                    updot(_mm256_sign_epi8(avec2, avec2),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
+                                    Cv[j][2]);
+                    Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+                                    updot(_mm256_sign_epi8(avec3, avec3),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
+                                    Cv[j][3]);
+                }
+            }
+
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < 4; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+
+    // Templated function for gemm of dimensions Mx4
+    template <int RM>
+    NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / 4;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * 4;
+            __m256 Cv[4][RM] = {};
+            for (int64_t l = 0; l < k; ++l) {
+                uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
+                // Convert the delta values of the four B blocks to float
+                __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
+                __m256i bvec0 = load(B + ldb * (jj + 0) + l);
+                __m256i bvec1 = load(B + ldb * (jj + 1) + l);
+                __m256i bvec2 = load(B + ldb * (jj + 2) + l);
+                __m256i bvec3 = load(B + ldb * (jj + 3) + l);
+                for (int64_t i = 0; i < RM; ++i) {
+                    __m128 da = _mm_set1_ps(unhalf(A[lda * (ii + i) + l].d));
+                    // Multiply the delta values of the four blocks pairwise and replicate the products across the 256-bit lane
+                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
+                    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
+                    // Compute the dot products and scale them by the matching delta products
+                    Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
+                                    Cv[0][i]);
+                    Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
+                                    Cv[1][i]);
+                    Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
+                                    Cv[2][i]);
+                    Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
+                                    Cv[3][i]);
+                }
+            }
+            for (int64_t j = 0; j < 4; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+#endif
+
     template <int RM, int RN>
     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t ytiles = (m - m0) / RM;
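
Note: the key difference between the new gemm4xN/gemmMx4 kernels and the generic gemm<RM, RN> path is how the per-block fp16 scales (the .d fields) are handled. Four adjacent block deltas are packed into one uint64_t, widened to float with a single _mm_cvtph_ps, multiplied by the other operand's scale, and the four products are then broadcast one block at a time with _mm256_permute2f128_ps plus _mm256_shuffle_ps (immediates 0, 85, 170, 255 pick elements 0..3). The standalone sketch below reproduces just that packing/broadcast step; it is not part of the patch, the values and variable names are illustrative, and it assumes a toolchain with AVX and F16C (e.g. build with -mavx2 -mf16c).

#include <cstdint>
#include <cstdio>
#include <immintrin.h>

int main() {
    // Pretend these are the fp16 .d scale fields of four consecutive blocks.
    __m128  scales_f32 = _mm_setr_ps(0.5f, 1.0f, 2.0f, 4.0f);
    __m128i scales_f16 = _mm_cvtps_ph(scales_f32, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

    // Pack the four 16-bit halves into one uint64_t, like a_delta / b_delta in the patch.
    uint16_t h[8];
    _mm_storeu_si128((__m128i *)h, scales_f16);
    uint64_t packed = ((uint64_t)h[3] << 48) | ((uint64_t)h[2] << 32) |
                      ((uint64_t)h[1] << 16) |  (uint64_t)h[0];

    // Convert all four fp16 deltas back to float at once (F16C),
    // then replicate them across both 128-bit lanes of a __m256.
    __m128 da   = _mm_cvtph_ps(_mm_set_epi64x(0, packed));
    __m256 dvec = _mm256_castps128_ps256(da);
    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);

    // _mm256_shuffle_ps with immediate 170 broadcasts element 2 of each lane,
    // yielding a full-width scale vector for block 2 (expected value 2.0).
    float out[8];
    _mm256_storeu_ps(out, _mm256_shuffle_ps(dvec, dvec, 170));
    printf("block 2 scale broadcast: %f ... %f\n", out[0], out[7]);
    return 0;
}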
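
The new kernels also keep the same static work partitioning as the existing gemm template: the output is cut into RM x RN tiles and each of the nth threads claims a contiguous run of ceil(tiles / nth) tile jobs based on its index ith. A minimal standalone sketch of that arithmetic, with illustrative names (partition is a hypothetical helper, not code from the patch):

#include <cstdint>
#include <cstdio>

// Reproduces the ytiles/xtiles/duty/start/end bookkeeping from gemm4xN/gemmMx4.
static void partition(int64_t m0, int64_t m, int64_t n0, int64_t n,
                      int64_t RM, int64_t RN, int64_t ith, int64_t nth) {
    int64_t ytiles = (m - m0) / RM;           // full tiles along the rows
    int64_t xtiles = (n - n0) / RN;           // full tiles along the columns
    int64_t tiles  = xtiles * ytiles;
    int64_t duty   = (tiles + nth - 1) / nth; // ceil(tiles / nth) jobs per thread
    int64_t start  = duty * ith;
    int64_t end    = start + duty;
    if (end > tiles)
        end = tiles;
    for (int64_t job = start; job < end; ++job) {
        int64_t ii = m0 + job / xtiles * RM;  // top row of this tile
        int64_t jj = n0 + job % xtiles * RN;  // left column of this tile
        printf("thread %lld: tile at (%lld, %lld)\n",
               (long long)ith, (long long)ii, (long long)jj);
    }
}

int main() {
    // e.g. a 16 x 12 output, 4 x 2 tiles, split across 3 threads
    for (int64_t ith = 0; ith < 3; ++ith)
        partition(0, 16, 0, 12, 4, 2, ith, 3);
    return 0;
}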