Commit ea5d747

sgemm : improved Q4_0 and Q8_0 performance via 4xN and Mx4 gemm (#8908)
1 parent 49271ef commit ea5d747

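The change adds dedicated 4xN and Mx4 quantized GEMM kernels for Q4_0 and Q8_0 blocks, selected when AVX2 and F16C are available and falling back to the generic gemm<RM, RN> template otherwise. The new kernels keep the existing tile-and-thread partitioning scheme; the standalone sketch below is not part of the patch, only an illustration of how the ytiles/xtiles/duty arithmetic that appears in gemm4xN splits 4-row by RN-column output tiles across nth threads (all sizes here are made up).

// Illustrative only: mirrors the job-partitioning arithmetic in gemm4xN below.
// Compile with: g++ -O2 demo_tiles.cpp
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t m0 = 0, m = 8, n0 = 0, n = 6, RN = 3;
    const int64_t nth = 2;                        // number of threads
    int64_t ytiles = (m - m0) / 4;                // row tiles of height 4
    int64_t xtiles = (n - n0) / RN;               // column tiles of width RN
    int64_t tiles  = xtiles * ytiles;
    int64_t duty   = (tiles + nth - 1) / nth;     // tiles per thread, rounded up
    for (int64_t ith = 0; ith < nth; ++ith) {
        int64_t start = duty * ith;
        int64_t end   = start + duty;
        if (end > tiles) end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * 4;   // first output row of the tile
            int64_t jj = n0 + job % xtiles * RN;  // first output column of the tile
            printf("thread %lld: rows %lld..%lld, cols %lld..%lld\n",
                   (long long)ith, (long long)ii, (long long)(ii + 3),
                   (long long)jj, (long long)(jj + RN - 1));
        }
    }
    return 0;
}
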
File tree

1 file changed, +149 -0 lines changed

ggml/src/llamafile/sgemm.cpp

Lines changed: 149 additions & 0 deletions
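
In the hunks that follow, case labels such as 0x44 or 0x43 encode the tile shape picked by the surrounding dispatch switch: the high nibble is the number of rows of A and the low nibble the number of columns of B handled per tile. The exact dispatch expression is outside this diff, so the helper below is only a hedged sketch of how such a key could be formed.

// Hypothetical helper, not from the patch: how a case label like 0x43
// (4 rows x 3 columns) could be derived from the remaining problem size.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static int tile_key(int64_t m0, int64_t m, int64_t n0, int64_t n) {
    int64_t mc = std::min<int64_t>(m - m0, 4);  // rows of A per tile, capped at 4
    int64_t nc = std::min<int64_t>(n - n0, 4);  // columns of B per tile, capped at 4
    return (int)((mc << 4) | nc);
}

int main() {
    printf("0x%02x\n", tile_key(0, 7, 0, 3));   // 7 rows, 3 cols remaining -> prints 0x43
    return 0;
}
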
@@ -606,17 +606,29 @@ class tinyBLAS_Q0_AVX {
         case 0x44:
             mc = 4;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<4>(m0, m, n0, n);
+#else
             gemm<4, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x43:
             mc = 4;
             nc = 3;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<3>(m0, m, n0, n);
+#else
             gemm<4, 3>(m0, m, n0, n);
+#endif
             break;
         case 0x34:
             mc = 3;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<3>(m0, m, n0, n);
+#else
             gemm<3, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x33:
             mc = 3;
@@ -626,26 +638,42 @@ class tinyBLAS_Q0_AVX {
         case 0x42:
             mc = 4;
             nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<2>(m0, m, n0, n);
+#else
             gemm<4, 2>(m0, m, n0, n);
+#endif
             break;
         case 0x24:
             mc = 2;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<2>(m0, m, n0, n);
+#else
             gemm<2, 4>(m0, m, n0, n);
+#endif
             break;
 #else
         case 0x44:
         case 0x43:
         case 0x42:
             mc = 4;
             nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<2>(m0, m, n0, n);
+#else
             gemm<4, 2>(m0, m, n0, n);
+#endif
             break;
         case 0x34:
         case 0x24:
             mc = 2;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<2>(m0, m, n0, n);
+#else
             gemm<2, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x33:
 #endif
@@ -662,7 +690,11 @@ class tinyBLAS_Q0_AVX {
         case 0x41:
             mc = 4;
             nc = 1;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<1>(m0, m, n0, n);
+#else
             gemm<4, 1>(m0, m, n0, n);
+#endif
             break;
         case 0x22:
             mc = 2;
@@ -672,7 +704,11 @@ class tinyBLAS_Q0_AVX {
         case 0x14:
             mc = 1;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<1>(m0, m, n0, n);
+#else
             gemm<1, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x31:
             mc = 3;
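
The last hunk adds the gemm4xN and gemmMx4 kernels themselves. Their central trick is to pack the four half-precision block scales (the .d "delta" fields) of a 4-row or 4-column tile into one 64-bit value, convert them with a single _mm_cvtph_ps, and then broadcast each scale product with _mm256_shuffle_ps immediates 0, 85, 170 and 255 (element 0, 1, 2 and 3 of each 128-bit lane). The standalone sketch below reproduces that pattern with made-up scale values instead of the quantized block layout from the file.

// Minimal sketch (not from the commit) of the a_delta/dvec pattern used by gemm4xN below.
// Compile with: g++ -mavx2 -mf16c demo_delta.cpp
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
    // Four half-precision scales stored as raw 16-bit patterns
    // (0x3C00 = 1.0f, 0x4000 = 2.0f, 0x4200 = 3.0f, 0x4400 = 4.0f).
    uint16_t d[4] = {0x3C00, 0x4000, 0x4200, 0x4400};

    // Pack the four fp16 values into one 64-bit integer, lowest block first,
    // matching the a_delta computation in the patch.
    uint64_t packed = ((uint64_t)d[3] << 48) | ((uint64_t)d[2] << 32) |
                      ((uint64_t)d[1] << 16) |  (uint64_t)d[0];

    // One F16C conversion yields all four scales as floats.
    __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, (long long)packed));

    // Multiply by the other operand's scale and replicate the 128-bit result
    // into both halves of a 256-bit register.
    __m128 db = _mm_set1_ps(0.5f);
    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);

    // _mm256_shuffle_ps with immediate 170 (0b10101010) broadcasts element 2
    // of each 128-bit lane, i.e. the third block's scale product.
    float out[8];
    _mm256_storeu_ps(out, _mm256_shuffle_ps(dvec, dvec, 170));
    printf("block 2 scale product: %f\n", out[0]);  // expect 1.5 (3.0 * 0.5)
    return 0;
}
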
@@ -708,6 +744,119 @@ class tinyBLAS_Q0_AVX {
         mnpack(m0, m, np, n);
     }
 
+#if defined(__AVX2__) && defined(__F16C__)
+    // Templated functions for gemm of dimensions 4xN
+    template <int RN>
+    NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / 4;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * 4;
+            int64_t jj = n0 + job % xtiles * RN;
+            __m256 Cv[RN][4] = {};
+            for (int64_t l = 0; l < k; ++l) {
+                uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
+                // Convert delta values for four blocks to float values
+                __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
+                __m256i avec0 = load(A + lda * (ii + 0) + l);
+                __m256i avec1 = load(A + lda * (ii + 1) + l);
+                __m256i avec2 = load(A + lda * (ii + 2) + l);
+                __m256i avec3 = load(A + lda * (ii + 3) + l);
+                for (int64_t j = 0; j < RN; ++j) {
+                    __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
+                    // Computation of product of delta values for four blocks and replicate it across 256 bit lane
+                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
+                    dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
+                    // Computation of dot product and multiplication with appropriate delta value products
+                    Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+                                    updot(_mm256_sign_epi8(avec0, avec0),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
+                                    Cv[j][0]);
+                    Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+                                    updot(_mm256_sign_epi8(avec1, avec1),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
+                                    Cv[j][1]);
+                    Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+                                    updot(_mm256_sign_epi8(avec2, avec2),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
+                                    Cv[j][2]);
+                    Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+                                    updot(_mm256_sign_epi8(avec3, avec3),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
+                                    Cv[j][3]);
+                }
+            }
+
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < 4; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+
+    // Templated functions for gemm of dimensions Mx4
+    template <int RM>
+    NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / 4;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * 4;
+            __m256 Cv[4][RM] = {};
+            for (int64_t l = 0; l < k; ++l) {
+                uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
+                // Convert delta values for four blocks to float values
+                __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
+                __m256i bvec0 = load(B + ldb * (jj + 0) + l);
+                __m256i bvec1 = load(B + ldb * (jj + 1) + l);
+                __m256i bvec2 = load(B + ldb * (jj + 2) + l);
+                __m256i bvec3 = load(B + ldb * (jj + 3) + l);
+                for (int64_t i = 0; i < RM; ++i) {
+                    __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
+                    // Computation of product of delta values for four blocks and replicate it across 256 bit lane
+                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
+                    dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
+                    // Computation of dot product and multiplication with appropriate delta value products
+                    Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
+                                    Cv[0][i]);
+                    Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
+                                    Cv[1][i]);
+                    Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
+                                    Cv[2][i]);
+                    Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
+                                    Cv[3][i]);
+                }
+            }
+            for (int64_t j = 0; j < 4; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+#endif
+
     template <int RM, int RN>
     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t ytiles = (m - m0) / RM;

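Both new kernels feed updot() through _mm256_sign_epi8 on each operand, the usual way to use the unsigned-by-signed _mm256_maddubs_epi16 multiply on signed int8 data: A is replaced by its absolute value and B takes on A's sign, so each lane's product is unchanged. The sketch below shows that identity on its own, independent of the updot()/madd() helpers in sgemm.cpp (a minimal example, not the library's code).

// Illustrative only. Compile with: g++ -mavx2 demo_sign.cpp
#include <immintrin.h>
#include <cstdio>

int main() {
    __m256i a = _mm256_set1_epi8(-3);
    __m256i b = _mm256_set1_epi8(5);
    __m256i ua = _mm256_sign_epi8(a, a);          // |a|  -> 3 in every lane
    __m256i sb = _mm256_sign_epi8(b, a);          // b with a's sign -> -5 in every lane
    // maddubs: (unsigned)ua * (signed)sb, adjacent byte products summed into 16-bit lanes
    __m256i prod = _mm256_maddubs_epi16(ua, sb);  // 3*(-5) + 3*(-5) = -30 per 16-bit lane
    short out[16];
    _mm256_storeu_si256((__m256i *)out, prod);
    printf("%d\n", out[0]);                       // prints -30, same as (-3)*5 + (-3)*5
    return 0;
}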