Commit 1cd06fa

CUDA: launch_bounds, small q4_K, q5_K mmq refactor (#2596)
1 parent 2feb893 commit 1cd06fa

File tree: 1 file changed

ggml-cuda.cu: 68 additions & 26 deletions
@@ -1753,7 +1753,6 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
 }
 
 // contiguous u/y values
-// also used for q5_K
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -1763,19 +1762,18 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     float sumf_m = 0.0f;
 
 #pragma unroll
-    for (int i0 = 0; i0 < VDR_Q4_K_Q8_1_MMQ; i0 += (QI8_1/QR4_K)) {
+    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
         int sumi_d = 0;
 
 #pragma unroll
-        for (int i = i0; i < i0 + (QI8_1/QR4_K); ++i) {
-            sumi_d = __dp4a(v[2*i+0], u[2*i+0], sumi_d); // SIMD dot product
-            sumi_d = __dp4a(v[2*i+1], u[2*i+1], sumi_d); // SIMD dot product
+        for (int j = 0; j < QI8_1; ++j) {
+            sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
         }
 
-        const float2 ds8f = __half22float2(ds8[i0 / 4]);
+        const float2 ds8f = __half22float2(ds8[i]);
 
-        sumf_d += ds8f.x * (sc[i0/4] * sumi_d);
-        sumf_m += ds8f.y * m[i0/4]; // sum of q8_1 block * q4_K min val
+        sumf_d += ds8f.x * (sc[i] * sumi_d);
+        sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
     }
 
     const float2 dm4f = __half22float2(dm4);
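
The restructured loop consumes the 4-bit quants straight from the packed words: for i == 0 the shift-and-mask keeps the low nibble of each byte, for i == 1 the high nibble, and each extracted word feeds one __dp4a against the matching q8_1 word u[i*QI8_1 + j]. A minimal host-side sketch of the extraction (not part of the commit; the packed value is hypothetical):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const uint32_t v = 0xA1B2C3D4; // one word packing 8 hypothetical 4-bit quants

        for (int i = 0; i < 2; ++i) {
            // i == 0 -> 0x01020304 (low nibbles), i == 1 -> 0x0A0B0C0D (high nibbles)
            const uint32_t q = (v >> (4*i)) & 0x0F0F0F0F;
            printf("pass %d: 0x%08X\n", i, (unsigned) q);
        }
        return 0;
    }
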
@@ -1792,7 +1790,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
 #define VDR_Q5_K_Q8_1_MMQ 8
 
 // contiguous v/x values
-static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
     const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
 
@@ -1829,6 +1827,40 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
+        int sumi_d = 0;
+
+#pragma unroll
+        for (int j = 0; j < QI8_1; ++j) {
+            sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
+        }
+
+        const float2 ds8f = __half22float2(ds8[i]);
+
+        sumf_d += ds8f.x * (sc[i] * sumi_d);
+        sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q5_K min val
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+
+#else
+    assert(false);
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
 #define VDR_Q6_K_Q8_1_MMVQ 1
 #define VDR_Q6_K_Q8_1_MMQ 8
 
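For reference, __dp4a(a, b, c) is the CUDA integer intrinsic (compute capability 6.1 and up, hence the MIN_CC_DP4A guard) that returns c plus the dot product of the four signed 8-bit lanes of a and b. A plain-C sketch of its behavior, for illustration only:

    #include <stdint.h>
    #include <string.h>

    // Reference behavior of the signed __dp4a(int, int, int) variant:
    // c plus the sum of the four byte-wise products of a and b.
    static int dp4a_ref(int a, int b, int c) {
        int8_t a8[4], b8[4];
        memcpy(a8, &a, sizeof a8);
        memcpy(b8, &b, sizeof b8);
        return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
    }

Reading off the code above, with dm4 = (d, dmin) and ds8[i] = (d8_i, s8_i), both impl_mmq functions return d * sum_i(d8_i * sc[i] * sumi_d_i) - dmin * sum_i(s8_i * m[i]): the scaled integer dot products minus the per-q8_1-block minimum compensation.
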
@@ -2824,18 +2856,11 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
 
-    int v[QR4_K*VDR_Q4_K_Q8_1_MMQ];
-
-#pragma unroll
-    for (int l = 0; l < VDR_Q4_K_Q8_1_MMQ; ++l) {
-        v[l + 0]         = (x_ql[i * (WARP_SIZE + 1) + k + l] >> 0) & 0x0F0F0F0F;
-        v[l + (QI4_K/4)] = (x_ql[i * (WARP_SIZE + 1) + k + l] >> 4) & 0x0F0F0F0F;
-    }
-
     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
 
     const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
-    return vec_dot_q4_K_q8_1_impl_mmq(v, &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
+    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
+                                      x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
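
With the nibble extraction folded into vec_dot_q4_K_q8_1_impl_mmq, the caller no longer builds the int v[QR4_K*VDR_Q4_K_Q8_1_MMQ] scratch array on every call; the impl now reads directly from the shared x_ql tile, which presumably lowers per-thread register pressure in the q4_K mmq kernel.
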
@@ -2882,7 +2907,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
         u[2*i+1] = q8[4];
     }
 
-    return vec_dot_q5_K_q8_1_impl(vl, vh, u, sc, m, bq5_K->dm, d8);
+    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
 
 #else
 
@@ -3025,7 +3050,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
 
     const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
     const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;
-    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
+    return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
+                                      x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
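
Previously the q5_K mmq path reused vec_dot_q4_K_q8_1_impl_mmq (hence the removed "also used for q5_K" comment above); that worked because the q5_K tile loader has already merged the high bit into full byte values in x_ql, so no extraction was needed. Since the q4_K impl now shifts and masks its inputs, q5_K gets its own vec_dot_q5_K_q8_1_impl_mmq that consumes v[i*QI8_1 + j] as-is.
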
@@ -3301,7 +3327,11 @@ template <bool need_check> static __global__ void mul_mat_q4_0(
 #define MMQ_Y_Q4_1_PASCAL 64
 #define NWARPS_Q4_1_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q4_1(
+template <bool need_check> static __global__ void
+#if __CUDA_ARCH__ < CC_TURING
+__launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2)
+#endif // __CUDA_ARCH__ < CC_TURING
+mul_mat_q4_1(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
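__launch_bounds__(max_threads_per_block, min_blocks_per_multiprocessor) tells the compiler to cap register usage so that, here, blocks of WARP_SIZE*NWARPS_Q4_1_PASCAL = 32*8 = 256 threads can run with at least 2 resident blocks per SM on pre-Turing GPUs. A standalone sketch of the attribute on a hypothetical kernel (illustration only, not from this commit):

    // Cap at 256 threads per block and require register allocation that
    // permits at least 2 resident blocks per multiprocessor.
    __global__ void __launch_bounds__(256, 2) scale_f32(float * x, const float v, const int n) {
        const int i = blockDim.x*blockIdx.x + threadIdx.x;
        if (i < n) {
            x[i] *= v;
        }
    }

The same guarded attribute is applied to mul_mat_q3_K, mul_mat_q4_K, and mul_mat_q6_K below.
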
@@ -3471,7 +3501,11 @@ template <bool need_check> static __global__ void mul_mat_q2_K(
 #define MMQ_Y_Q3_K_PASCAL 64
 #define NWARPS_Q3_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q3_K(
+template <bool need_check> static __global__ void
+#if __CUDA_ARCH__ < CC_TURING
+__launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2)
+#endif // __CUDA_ARCH__ < CC_TURING
+mul_mat_q3_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
@@ -3501,11 +3535,15 @@ template <bool need_check> static __global__ void mul_mat_q3_K(
 #define MMQ_X_Q4_K_AMPERE 64
 #define MMQ_Y_Q4_K_AMPERE 128
 #define NWARPS_Q4_K_AMPERE 4
-#define MMQ_X_Q4_K_PASCAL 32
+#define MMQ_X_Q4_K_PASCAL 64
 #define MMQ_Y_Q4_K_PASCAL 64
 #define NWARPS_Q4_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q4_K(
+template <bool need_check> static __global__ void
+#if __CUDA_ARCH__ < CC_TURING
+__launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2)
+#endif // __CUDA_ARCH__ < CC_TURING
+mul_mat_q4_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
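With occupancy pinned by the launch bounds, the Pascal tile width MMQ_X for q4_K (and q6_K below) is raised from 32 to 64, presumably because the register cap from __launch_bounds__ avoids the spilling that the smaller tile worked around.
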
@@ -3569,11 +3607,15 @@ template <bool need_check> static __global__ void mul_mat_q5_K(
 #define MMQ_X_Q6_K_AMPERE 64
 #define MMQ_Y_Q6_K_AMPERE 64
 #define NWARPS_Q6_K_AMPERE 4
-#define MMQ_X_Q6_K_PASCAL 32
+#define MMQ_X_Q6_K_PASCAL 64
 #define MMQ_Y_Q6_K_PASCAL 64
 #define NWARPS_Q6_K_PASCAL 8
 
-template <bool need_check> static __global__ void mul_mat_q6_K(
+template <bool need_check> static __global__ void
+#if __CUDA_ARCH__ < CC_TURING
+__launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2)
+#endif // __CUDA_ARCH__ < CC_TURING
+mul_mat_q6_K(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 