
Commit bf10e13

CUDA: int8 tensor cores for MMQ (legacy quants)

1 parent e05ea1c

7 files changed: +546 additions, -55 deletions

ggml-cuda/common.cuh

Lines changed: 17 additions & 4 deletions
@@ -139,6 +139,7 @@
 #define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
+#define CC_TURING     750
 #define CC_AMPERE     800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1      (CC_OFFSET_AMD + 1010)
@@ -326,9 +327,17 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
 #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
 #endif // defined(GGML_USE_HIPBLAS)
 
-#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#define FP16_AVAILABLE
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
 
-#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#define FP16_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#define INT8_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 
 static bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
@@ -338,6 +347,10 @@ static bool fp16_mma_available(const int cc) {
     return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
 }
 
+static bool int8_mma_available(const int cc) {
+    return cc < CC_OFFSET_AMD && cc >= CC_TURING;
+}
+
 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
@@ -379,7 +392,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 }
 
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
 
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
@@ -412,7 +425,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
 }
 
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
 
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
     return __float2half(fmaxf(__half2float(a), __half2float(b)));
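
The central change here is that FP16_AVAILABLE and FP16_MMA_AVAILABLE stop being expression-valued macros and become presence flags, with INT8_MMA_AVAILABLE added in the same style and int8_mma_available() as its host-side counterpart (NVIDIA GPUs with cc >= CC_TURING). A minimal sketch of why the flag style is preferable; FEATURE_AVAILABLE, SOME_FEATURE_FLAG, and the cutoff 600 are illustrative names, not from the commit:

    // Old pattern: the condition is the macro body. Used later as
    // "#if FEATURE_AVAILABLE", the defined(...) operators only appear after
    // macro expansion, which is undefined behavior per the C/C++ standards.
    //#define FEATURE_AVAILABLE defined(SOME_FEATURE_FLAG) || __CUDA_ARCH__ >= 600

    // New pattern: evaluate the condition once, define an empty flag macro,
    // and test mere presence everywhere else with #ifdef.
    #if defined(SOME_FEATURE_FLAG) || __CUDA_ARCH__ >= 600
    #define FEATURE_AVAILABLE
    #endif // defined(SOME_FEATURE_FLAG) || __CUDA_ARCH__ >= 600

    #ifdef FEATURE_AVAILABLE
    // ... feature-specific device code ...
    #endif // FEATURE_AVAILABLE

This is also why every consumer in the files below changes from #if FP16_AVAILABLE to #ifdef FP16_AVAILABLE.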

ggml-cuda/fattn-common.cuh

Lines changed: 10 additions & 10 deletions
@@ -74,7 +74,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
 
         const int sumi = __dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
@@ -122,7 +122,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
 
         const int sumi = __dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
@@ -181,7 +181,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
 
         const int sumi = __dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
@@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
 
         const int sumi = __dp4a(v, u, 0);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
@@ -314,7 +314,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         const half2 * Q_h2 = (const half2 *) Q_v;
 
@@ -407,7 +407,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
     const int q0 = x[ib].qs[iqs];
     const int q  = ((q0 >> (4*shift)) & 0x0F) - 8;
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return ((half) d)*((half) q);
     }
@@ -428,7 +428,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__
     const int q0 = x[ib].qs[iqs];
     const int q  = ((q0 >> (4*shift)) & 0x0F);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return __low2half(dm)*((half) q) + __high2half(dm);
     }
@@ -453,7 +453,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
     const int qh = ((qh0 >> idq) << 4) & 0x10;
     const int q  = (ql | qh) - 16;
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return ((half) d)*((half) q);
     }
@@ -478,7 +478,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
     const int qh = ((qh0 >> idq) << 4) & 0x10;
     const int q  = (ql | qh);
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     if (std::is_same<T, half>::value) {
         return __low2half(dm)*((half) q) + __high2half(dm);
     }
@@ -497,7 +497,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__
     const T   d = x[ib].d;
     const int q = x[ib].qs[iqs];
 
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
    if (std::is_same<T, half>::value) {
        return ((half) d)*((half) q);
    }
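
All ten hunks in this file make the same one-token substitution inside templated helpers that pick the half or float path at compile time. A self-contained sketch of that dispatch pattern; dequantize_1_sketch is an illustrative name, not from the commit:

    #include <cuda_fp16.h>
    #include <type_traits>

    // When FP16_AVAILABLE is defined and T == half, the branch is resolved at
    // compile time and the dead FP32 path is eliminated; without the flag,
    // only the FP32 fallback exists.
    template <typename T>
    static __device__ __forceinline__ T dequantize_1_sketch(const half d, const int q) {
    #ifdef FP16_AVAILABLE
        if (std::is_same<T, half>::value) {
            return ((half) d)*((half) q); // native FP16 arithmetic
        }
    #endif // FP16_AVAILABLE

        return ((float) d)*((float) q);   // FP32 fallback
    }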

ggml-cuda/fattn-tile-f16.cu

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.

ggml-cuda/fattn-vec-f16.cuh

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if FP16_AVAILABLE
+#ifdef FP16_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);

ggml-cuda/fattn-wmma-f16.cuh

Lines changed: 3 additions & 3 deletions
@@ -1,9 +1,9 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-#if FP16_MMA_AVAILABLE
+#ifdef FP16_MMA_AVAILABLE
 #include <mma.h>
-#endif
+#endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if FP16_MMA_AVAILABLE
+#ifdef FP16_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.

ggml-cuda/mma.cuh

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+#include "common.cuh"
+
+struct mma_int_A_I16K8 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 8;
+    static constexpr int ne = 4;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
+        __builtin_assume(ret >= 0);
+        __builtin_assume(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int l) {
+        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
+        __builtin_assume(ret >= 0);
+        __builtin_assume(ret <  K);
+        return ret;
+    }
+};
+
+struct mma_int_B_J8K8 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 8;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / (K/2);
+        __builtin_assume(ret >= 0);
+        __builtin_assume(ret <  J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int l) {
+        const int ret = l * (K/2) + threadIdx.x % (K/2);
+        __builtin_assume(ret >= 0);
+        __builtin_assume(ret <  K);
+        return ret;
+    }
+};
+
+struct mma_int_C_I16J8 {
+    static constexpr int I  = 16;
+    static constexpr int J  = 8;
+    static constexpr int ne = 4;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
+        __builtin_assume(ret >= 0);
+        __builtin_assume(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_j(const int l) {
+        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
+        __builtin_assume(ret >= 0);
+        __builtin_assume(ret <  J);
+        return ret;
+    }
+
+    __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= CC_AMPERE
+        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
+#else
+        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+};
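
The three fragment types mirror the PTX mma operand shapes: per warp, A is a 16x8 tile of 32-bit registers (each register packs four int8 values, so 32 int8 elements along K), B is 8x8, and C accumulates a 16x8 tile of int32. get_i/get_j/get_k map a thread's register index l to the logical tile coordinates it owns, and mma_K8 issues one m16n8k32 instruction on Ampere or four m8n8k16 instructions on Turing. A hedged usage sketch follows, not part of the commit: warp_gemm_tile, the tile pointers, and the int-granular row-major strides are illustrative assumptions.

    #include "mma.cuh"

    // One warp computes C += A*B for a single tile. tile_A/tile_B hold packed
    // int8 data at int (4-byte) granularity, row-major with the given strides.
    static __device__ void warp_gemm_tile(
            const int * tile_A, const int stride_A,
            const int * tile_B, const int stride_B,
            int       * tile_C, const int stride_C) {

        mma_int_A_I16K8 A;
        mma_int_B_J8K8  B;
        mma_int_C_I16J8 C; // accumulator, x[] starts zero-initialized

        // Each of the 32 threads loads exactly the fragment entries it owns:
    #pragma unroll
        for (int l = 0; l < mma_int_A_I16K8::ne; ++l) {
            A.x[l] = tile_A[mma_int_A_I16K8::get_i(l)*stride_A + mma_int_A_I16K8::get_k(l)];
        }
    #pragma unroll
        for (int l = 0; l < mma_int_B_J8K8::ne; ++l) {
            B.x[l] = tile_B[mma_int_B_J8K8::get_j(l)*stride_B + mma_int_B_J8K8::get_k(l)];
        }

        // C[i][j] += sum_k A[i][k]*B[j][k]; B is supplied K-major, matching
        // the "row.col" layout of the PTX instructions above.
        C.mma_K8(A, B);

        // Scatter the int32 results back using the C fragment's layout:
    #pragma unroll
        for (int l = 0; l < mma_int_C_I16J8::ne; ++l) {
            tile_C[mma_int_C_I16J8::get_i(l)*stride_C + mma_int_C_I16J8::get_j(l)] = C.x[l];
        }
    }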
