Commit 9a590c8

CUDA: optimize MMQ int8 tensor core performance (#8062)
* CUDA: optimize MMQ int8 tensor core performance
* only a single get_mma_tile_x_k function
* simplify code, make functions constexpr
1 parent 52fc870 · commit 9a590c8

3 files changed (+879, -547 lines)

ggml-cuda/common.cuh

Lines changed: 2 additions & 2 deletions
@@ -643,7 +643,7 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
     static constexpr int qi = QI3_S;
 };

-static int get_mmq_x_max_host(const int cc) {
+static constexpr int get_mmq_x_max_host(int cc) {
 #ifdef CUDA_USE_TENSOR_CORES
     return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
 #else
@@ -652,7 +652,7 @@ static int get_mmq_x_max_host(const int cc) {
 }

 // Round rows to this value for --split-mode row:
-static int get_mmq_y_host(const int cc) {
+static constexpr int get_mmq_y_host(int cc) {
     return cc >= CC_VOLTA ? 128 : 64;
 }
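Marking these host helpers constexpr lets the selected tile sizes feed constant expressions (template arguments, static_asserts, array bounds) instead of only runtime branches. A minimal standalone sketch of the idea, using only the get_mmq_y_host body shown above; CC_VOLTA follows the usual compute-capability encoding (7.0 -> 700), and the example input 800 is just an illustration:

// Standalone sketch, not the real ggml-cuda sources: only get_mmq_y_host is
// reproduced from the patch above, everything else is illustration.
#include <cstdio>

static constexpr int CC_VOLTA = 700; // compute capability 7.0

// Same shape as the patched helper: constexpr, so the result is usable in
// constant expressions on the host.
static constexpr int get_mmq_y_host(int cc) {
    return cc >= CC_VOLTA ? 128 : 64;
}

int main() {
    constexpr int mmq_y = get_mmq_y_host(800); // example: an Ampere-class device
    static_assert(mmq_y == 128, "Volta or newer selects 128-row tiles");
    printf("mmq_y = %d\n", mmq_y);
    return 0;
}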

ggml-cuda/mma.cuh

Lines changed: 56 additions & 0 deletions
@@ -20,6 +20,20 @@ struct mma_int_A_I16K4 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };

 struct mma_int_A_I16K8 {
@@ -42,6 +56,20 @@ struct mma_int_A_I16K8 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };

 struct mma_int_B_J8K4 {
@@ -64,6 +92,20 @@ struct mma_int_B_J8K4 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(x[0])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };

 struct mma_int_B_J8K8 {
@@ -86,6 +128,20 @@ struct mma_int_B_J8K8 {
         GGML_CUDA_ASSUME(ret < K);
         return ret;
     }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
 };

 struct mma_int_C_I16J8 {
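At the mmq.cuh call sites (not part of this excerpt) these load() members replace open-coded per-element gathers from shared memory. The fragment below is a rough usage sketch only, assuming execution by a full warp; the tile pointers, strides, output layout, and the C fragment's members and mma_K8 accumulate call are assumptions inferred from the surrounding mma.cuh structs, not the actual kernel code:

// Illustrative sketch, not the real mmq.cuh call sites. A warp loads one
// 16x8 A fragment and one 8x8 B fragment of packed int8 data from shared
// memory via the new load() helpers, multiplies, and writes the int32 result.
#include "mma.cuh"

static __device__ __forceinline__ void mma_tile_example(
        const int * __restrict__ tile_A, // shared memory A tile, row stride stride_A (in ints)
        const int * __restrict__ tile_B, // shared memory B tile, row stride stride_B (in ints)
        const int stride_A, const int stride_B,
        int * __restrict__ dst) {        // 16x8 row-major int32 output
    mma_int_A_I16K8 A;
    mma_int_B_J8K8  B;
    mma_int_C_I16J8 C;

#pragma unroll
    for (int l = 0; l < mma_int_C_I16J8::ne; ++l) {
        C.x[l] = 0; // clear the accumulator fragment
    }

    A.load(tile_A, stride_A); // ldmatrix when INT8_MMA_AVAILABLE, scalar gather otherwise
    B.load(tile_B, stride_B); // B currently always takes the scalar path (note the "&& false")

    C.mma_K8(A, B);           // assumed name of the int8 tensor core accumulate in mma.cuh

#pragma unroll
    for (int l = 0; l < mma_int_C_I16J8::ne; ++l) {
        dst[C.get_i(l)*mma_int_C_I16J8::J + C.get_j(l)] = C.x[l];
    }
}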
