@@ -1753,7 +1753,6 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
 }
 
 // contiguous u/y values
-// also used for q5_K
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -1763,19 +1762,18 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     float sumf_m = 0.0f;
 
 #pragma unroll
-    for (int i0 = 0; i0 < VDR_Q4_K_Q8_1_MMQ; i0 += (QI8_1/QR4_K)) {
+    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
         int sumi_d = 0;
 
 #pragma unroll
-        for (int i = i0; i < i0 + (QI8_1/QR4_K); ++i) {
-            sumi_d = __dp4a(v[2*i+0], u[2*i+0], sumi_d); // SIMD dot product
-            sumi_d = __dp4a(v[2*i+1], u[2*i+1], sumi_d); // SIMD dot product
+        for (int j = 0; j < QI8_1; ++j) {
+            sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
         }
 
-        const float2 ds8f = __half22float2(ds8[i0 / 4]);
+        const float2 ds8f = __half22float2(ds8[i]);
 
-        sumf_d += ds8f.x * (sc[i0/4] * sumi_d);
-        sumf_m += ds8f.y * m[i0/4]; // sum of q8_1 block * q4_K min val
+        sumf_d += ds8f.x * (sc[i] * sumi_d);
+        sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
     }
 
     const float2 dm4f = __half22float2(dm4);
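Note on the restructured q4_K loop above: the 4-bit values now stay packed at the call site, and (v[j] >> (4*i)) & 0x0F0F0F0F picks the low nibbles of each byte on the first outer iteration and the high nibbles on the second, so the temporary unpacking array that used to be built in vec_dot_q4_K_q8_1_mul_mat (removed further down) is no longer needed. As a scalar sketch of what a single __dp4a(vi, ui, acc) call contributes (illustrative only, not part of the patch, with a hypothetical helper name, assuming the usual signed 8-bit semantics; the masked nibbles are 0..15, so their sign never matters):

    // Reference for one __dp4a call: four byte-wise products added to the accumulator.
    static int dp4a_ref(int vi, int ui, int acc) {
        for (int b = 0; b < 4; ++b) {
            const int v8 = (int)(signed char)((vi >> (8*b)) & 0xFF); // here: a masked q4_K nibble, 0..15
            const int u8 = (int)(signed char)((ui >> (8*b)) & 0xFF); // signed q8_1 quant
            acc += v8 * u8;
        }
        return acc;
    }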
@@ -1792,7 +1790,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
 #define VDR_Q5_K_Q8_1_MMQ 8
 
 // contiguous v/x values
-static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
     const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
 
@@ -1829,6 +1827,40 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
+        int sumi_d = 0;
+
+#pragma unroll
+        for (int j = 0; j < QI8_1; ++j) {
+            sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
+        }
+
+        const float2 ds8f = __half22float2(ds8[i]);
+
+        sumf_d += ds8f.x * (sc[i] * sumi_d);
+        sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+
+#else
+    assert(false);
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
 #define VDR_Q6_K_Q8_1_MMVQ 1
 #define VDR_Q6_K_Q8_1_MMQ 8
 
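The new vec_dot_q5_K_q8_1_impl_mmq mirrors the q4_K version but multiplies v and u directly, with no nibble shift: the q5_K tiles already hold the combined low/high bits (note the QR5_K*WARP_SIZE stride in index_x at the q5_K call site below). For the trip counts, assuming the usual ggml constants (QK8_1 = 32, hence QI8_1 = 8, and QR5_K = 2), the outer loop runs QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1 = 2*8/8 = 2 times and the inner loop QI8_1 = 8 times, i.e. 16 __dp4a calls or 64 int8 multiply-adds per call, with one (sc[i], m[i], ds8[i]) triple per q8_1 block. The parameter name dm4 is carried over from the q4_K variant even though it receives the q5_K block scale/min here.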
@@ -2824,18 +2856,11 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    int v[QR4_K*VDR_Q4_K_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < VDR_Q4_K_Q8_1_MMQ; ++l) {
-        v[l + 0]         = (x_ql[i * (WARP_SIZE + 1) + k + l] >> 0) & 0x0F0F0F0F;
-        v[l + (QI4_K/4)] = (x_ql[i * (WARP_SIZE + 1) + k + l] >> 4) & 0x0F0F0F0F;
-    }
-
     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
 
     const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
-    return vec_dot_q4_K_q8_1_impl_mmq(v, &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
+    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
+                                      x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
2841
2866
static __device__ __forceinline__ float vec_dot_q5_K_q8_1 (
@@ -2882,7 +2907,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
         u[2*i+1] = q8[4];
     }
 
-    return vec_dot_q5_K_q8_1_impl(vl, vh, u, sc, m, bq5_K->dm, d8);
+    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
 
 #else
 
@@ -3025,7 +3050,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
 
     const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
     const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;
-    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
+    return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
+                                      x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
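This call-site change pairs with the comment removed at the top of the diff ("also used for q5_K"): now that vec_dot_q4_K_q8_1_impl_mmq shifts and masks nibbles internally, feeding it the already expanded q5_K values would no longer be correct, so the q5_K path switches to the dedicated vec_dot_q5_K_q8_1_impl_mmq added above.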
@@ -3301,7 +3327,11 @@ template <bool need_check> static __global__ void mul_mat_q4_0(
 #define MMQ_Y_Q4_1_PASCAL 64
 #define NWARPS_Q4_1_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q4_1(
+template <bool need_check> static __global__ void
+#if __CUDA_ARCH__ < CC_TURING
+    __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
+#endif // __CUDA_ARCH__ < CC_TURING
+    mul_mat_q4_1(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
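__launch_bounds__(maxThreadsPerBlock, minBlocksPerMultiprocessor) is the standard CUDA attribute: the first argument caps the block size the kernel may be launched with, and the second asks the compiler to limit register use so at least that many blocks can be resident per SM. Here that is WARP_SIZE*NWARPS_Q4_1_PASCAL = 32*8 = 256 threads and 2 blocks per SM, guarded so it only takes effect in pre-Turing device compilation passes; the same pattern is repeated for the kernels below. A minimal standalone sketch with a hypothetical kernel, not taken from the patch:

    // At most 256 threads per block; compiler targets >= 2 resident blocks per SM.
    static __global__ void __launch_bounds__(256, 2) scale_kernel(float * dst, const float scale, const int n) {
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            dst[i] *= scale;
        }
    }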
@@ -3471,7 +3501,11 @@ template <bool need_check> static __global__ void mul_mat_q2_K(
 #define MMQ_Y_Q3_K_PASCAL 64
 #define NWARPS_Q3_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q3_K(
+template <bool need_check> static __global__ void
+#if __CUDA_ARCH__ < CC_TURING
+    __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
+#endif // __CUDA_ARCH__ < CC_TURING
+    mul_mat_q3_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
@@ -3501,11 +3535,15 @@ template <bool need_check> static __global__ void mul_mat_q3_K(
 #define MMQ_X_Q4_K_AMPERE 64
 #define MMQ_Y_Q4_K_AMPERE 128
 #define NWARPS_Q4_K_AMPERE 4
-#define MMQ_X_Q4_K_PASCAL 32
+#define MMQ_X_Q4_K_PASCAL 64
 #define MMQ_Y_Q4_K_PASCAL 64
 #define NWARPS_Q4_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q4_K(
+template <bool need_check> static __global__ void
+#if __CUDA_ARCH__ < CC_TURING
+    __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
+#endif // __CUDA_ARCH__ < CC_TURING
+    mul_mat_q4_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
@@ -3569,11 +3607,15 @@ template <bool need_check> static __global__ void mul_mat_q5_K(
 #define MMQ_X_Q6_K_AMPERE 64
 #define MMQ_Y_Q6_K_AMPERE 64
 #define NWARPS_Q6_K_AMPERE 4
-#define MMQ_X_Q6_K_PASCAL 32
+#define MMQ_X_Q6_K_PASCAL 64
 #define MMQ_Y_Q6_K_PASCAL 64
 #define NWARPS_Q6_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q6_K(
+template <bool need_check> static __global__ void
+#if __CUDA_ARCH__ < CC_TURING
+    __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
+#endif // __CUDA_ARCH__ < CC_TURING
+    mul_mat_q6_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
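Alongside the launch-bounds hints, MMQ_X_Q4_K_PASCAL and MMQ_X_Q6_K_PASCAL are raised from 32 to 64, i.e. the Pascal tiles become as wide in x as the Ampere ones; presumably the occupancy hint is what keeps register pressure acceptable at the larger tile size.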