
Commit 2879780

mul_mat_q8_0 with tensor cores
1 parent 16926df, commit 2879780

File tree

1 file changed (+178, -6 lines)


ggml-cuda/mmq.cu

Lines changed: 178 additions & 6 deletions
@@ -486,7 +486,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
     GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
 
-    __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
+    __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y*4];
     __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];
 
     *x_ql = tile_x_qs;
@@ -519,7 +519,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
 
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
+        x_ql[i * (WARP_SIZE + 4) + k] = get_int_from_int8(bxi->qs, kqsx);
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
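Note on the two hunks above: the per-row padding of the shared q8_0 tile grows from 1 int (stride WARP_SIZE + 1) to 4 ints (stride WARP_SIZE + 4), which is the stride the new mul_mat_q_test path further down uses for its fragment loads from tile_x_ql. A minimal index sketch under that reading (illustrative helper names, not part of mmq.cu):

// Illustrative only: element (row, k) of the padded q8_0 tile sits at
// row*(WARP_SIZE + 4) + k, the stride used by both load_tiles_q8_0 and the
// mma fragment loads in mul_mat_q_test.
constexpr int kWarpSize = 32;   // stand-in for ggml's WARP_SIZE

constexpr int tile_x_qs_index(int row, int k) {
    return row * (kWarpSize + 4) + k;   // 4 ints of padding per tile row
}

static_assert(tile_x_qs_index(1, 0) == 36, "row stride is WARP_SIZE + 4 ints");
static_assert(tile_x_qs_index(1, 0) % 4 == 0, "each row starts 4-int aligned");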
@@ -547,9 +547,20 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
     const float * x_dmf = (const float *) x_dm;
     const float * y_df  = (const float *) y_ds;
 
-    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
-        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
-         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
+    const int * v = &x_ql[i * (WARP_SIZE + 1) + k];
+    const int * u = &y_qs[j * WARP_SIZE + k];
+    const float d8_0 = x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0];
+    const float d8_1 = y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1];
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q8_0_Q8_1_MMQ; ++i) {
+        // SIMD dot product of quantized values
+        sumi = __dp4a(v[i], u[i], sumi);
+    }
+
+    return d8_0*d8_1 * sumi;
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
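The rewrite above inlines what vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ> previously did: the new mul_mat_q_test below asserts vdr == 32/sizeof(int), i.e. 8, so 8 packed ints (32 int8 quants) per block pair are reduced with __dp4a into one int32 and scaled once by d8_0*d8_1. A self-contained sketch of that reduction for a single block pair (demo code only, assumes an sm_61+ GPU; not part of the commit):

// Demo of the __dp4a reduction used above, for one q8_0 block against one
// q8_1 block: 8 ints hold the 32 packed int8 quants on each side.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void block_dot_q8(const int * v, const int * u, float d8_0, float d8_1, float * dst) {
    int sumi = 0;
#pragma unroll
    for (int i = 0; i < 8; ++i) {          // 8 ints == 32 packed int8 quants
        sumi = __dp4a(v[i], u[i], sumi);   // per-byte signed dot product with int32 accumulate
    }
    *dst = d8_0*d8_1 * sumi;               // the two block scales are applied once at the end
}

int main() {
    int h_v[8], h_u[8];
    for (int i = 0; i < 8; ++i) { h_v[i] = 0x01010101; h_u[i] = 0x01010101; }   // all quants == 1 -> sumi == 32

    int *d_v, *d_u; float *d_r;
    cudaMalloc(&d_v, sizeof(h_v)); cudaMalloc(&d_u, sizeof(h_u)); cudaMalloc(&d_r, sizeof(float));
    cudaMemcpy(d_v, h_v, sizeof(h_v), cudaMemcpyHostToDevice);
    cudaMemcpy(d_u, h_u, sizeof(h_u), cudaMemcpyHostToDevice);

    block_dot_q8<<<1, 1>>>(d_v, d_u, 0.5f, 0.25f, d_r);

    float r = 0.0f;
    cudaMemcpy(&r, d_r, sizeof(float), cudaMemcpyDeviceToHost);
    printf("block dot = %f (expected %f)\n", r, 0.5f*0.25f*32);

    cudaFree(d_v); cudaFree(d_u); cudaFree(d_r);
    return 0;
}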
@@ -1066,6 +1077,167 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }
 
+template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
+          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
+static __device__ __forceinline__ void mul_mat_q_test(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    const int blocks_per_row_x = ncols_x / qk;
+    const int blocks_per_col_y = nrows_y / QK8_1;
+    const int blocks_per_warp = WARP_SIZE / qi;
+
+    const int & ncols_dst = ncols_y;
+
+    const int row_dst_0 = blockIdx.x*mmq_y;
+    const int & row_x_0 = row_dst_0;
+
+    const int col_dst_0 = blockIdx.y*mmq_x;
+    const int & col_y_0 = col_dst_0;
+
+    int   * tile_x_ql = nullptr;
+    half2 * tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
+
+    allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
+
+    __shared__ int   tile_y_qs[mmq_x * WARP_SIZE + mmq_x*4];
+    __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
+
+    static_assert(mmq_x % (8*nwarps) == 0, "assert");
+    float sum[mmq_x/(8*nwarps)][mmq_y/16][4] = {{{0.0f}}};
+
+    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
+
+        load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
+                   threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
+
+#pragma unroll
+        for (int ir = 0; ir < qr; ++ir) {
+            const int kqs = ir*WARP_SIZE + threadIdx.x;
+            const int kbxd = kqs / QI8_1;
+
+#pragma unroll
+            for (int i = 0; i < mmq_x; i += nwarps) {
+                const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
+
+                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
+
+                const int index_y = (threadIdx.y + i) * (WARP_SIZE + 4) + kqs % WARP_SIZE;
+                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
+            }
+
+#pragma unroll
+            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
+                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
+                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
+                const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
+
+                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+                const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
+                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
+                if (need_sum) {
+                    *dsi_dst = *dsi_src;
+                } else {
+                    float * dfi_dst = (float *) dsi_dst;
+                    *dfi_dst = __low2float(*dsi_src);
+                }
+            }
+
+            __syncthreads();
+
+            const float * x_dmf = (const float *) tile_x_dm;
+            const float * y_df  = (const float *) tile_y_ds;
+
+            static_assert(!need_sum);
+            static_assert(vdr == 32/sizeof(int));
+#pragma unroll
+            for (int k0 = ir*WARP_SIZE/qr; k0 < (ir+1)*WARP_SIZE/qr; k0 += vdr) {
+#pragma unroll
+                for (int j00 = 0; j00 < mmq_x; j00 += 8*nwarps) {
+                    const int j0 = j00 + 8*threadIdx.y;
+#pragma unroll
+                    for (int i0 = 0; i0 < mmq_y; i0 += 16) {
+                        int v[4];
+#pragma unroll
+                        for (int l = 0; l < 4; ++l) {
+                            v[l] = tile_x_ql[(i0 + (l%2)*8 + threadIdx.x/4) * (WARP_SIZE + 4) + k0 + (l/2)*4 + threadIdx.x%4];
+                        }
+                        int u[2];
+#pragma unroll
+                        for (int l = 0; l < 2; ++l) {
+                            u[l] = tile_y_qs[(j0 + threadIdx.x/4) * (WARP_SIZE + 4) + k0 + 4*l + threadIdx.x%4];
+                        }
+
+                        int sumi[4] = {0};
+#if __CUDA_ARCH__ >= CC_AMPERE
+                        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+                            : "+r"(sumi[0]), "+r"(sumi[1]), "+r"(sumi[2]), "+r"(sumi[3])
+                            : "r"(v[0]), "r"(v[1]), "r"(v[2]), "r"(v[3]), "r"(u[0]), "r"(u[1]));
+#else
+                        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+                            : "+r"(sumi[0]), "+r"(sumi[1]), "+r"(sumi[2]), "+r"(sumi[3])
+                            : "r"(v[0]), "r"(v[1]), "r"(u[0]));
+                        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+                            : "+r"(sumi[0]), "+r"(sumi[1]), "+r"(sumi[2]), "+r"(sumi[3])
+                            : "r"(v[2]), "r"(v[3]), "r"(u[1]));
+#endif // __CUDA_ARCH__ >= CC_AMPERE
+
+                        float d8_0[2];
+#pragma unroll
+                        for (int l = 0; l < 2; ++l) {
+                            const int i = i0 + 8*l + threadIdx.x/4;
+                            d8_0[l] = x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
+                        }
+                        float d8_1[2];
+#pragma unroll
+                        for (int l = 0; l < 2; ++l) {
+                            const int j = j0 + 2*(threadIdx.x%4) + l;
+                            d8_1[l] = y_df[j * (WARP_SIZE/QI8_1) + k0/QI8_1];
+                        }
+
+#pragma unroll
+                        for (int l = 0; l < 4; ++l) {
+                            sum[j00/(8*nwarps)][i0/16][l] += d8_0[l/2]*d8_1[l%2] * sumi[l];
+                        }
+                    }
+                }
+            }
+
+            __syncthreads();
+        }
+    }
+
+#pragma unroll
+    for (int j00 = 0; j00 < mmq_x; j00 += 8*nwarps) {
+        const int j0 = j00 + 8*threadIdx.y + 2*(threadIdx.x%4);
+#pragma unroll
+        for (int i00 = 0; i00 < mmq_y; i00 += 16) {
+            const int i0 = i00 + threadIdx.x/4;
+
+#pragma unroll
+            for (int l = 0; l < 4; ++l) {
+                const int i = i0 + 8*(l/2);
+                const int j = j0 + (l%2);
+
+                const int row_dst = row_dst_0 + i;
+                const int col_dst = col_dst_0 + j;
+                if (row_dst >= nrows_dst) {
+                    continue;
+                }
+                if (col_dst >= ncols_dst) {
+                    continue;
+                }
+                dst[col_dst*nrows_dst + row_dst] = sum[j00/(8*nwarps)][i00/16][l];
+            }
+        }
+    }
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
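The heart of the new mul_mat_q_test above is the inline PTX: for each 16x8 output fragment, architectures at or above CC_AMPERE issue a single m16n8k32 int8 mma.sync, and everything else takes the #else branch with two m16n8k16 MMAs. Below is a self-contained demo of the m16n8k16 form (hypothetical test harness, requires sm_75 or newer; not part of the commit). With every A and B byte set to 1, all 16x8 accumulators must come back as k = 16, so the check does not depend on the per-thread fragment layout:

// One warp, one m16n8k16 s8 x s8 -> s32 tensor core MMA, as in the #else branch above.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void mma_s8_demo(int * out) {
    const int a[2] = {0x01010101, 0x01010101};  // 8 int8 values of A per thread, all 1
    const int b    =  0x01010101;               // 4 int8 values of B per thread, all 1
    int d[4]       = {0, 0, 0, 0};              // 4 int32 accumulators per thread

    asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
        : "+r"(d[0]), "+r"(d[1]), "+r"(d[2]), "+r"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(b));

    for (int l = 0; l < 4; ++l) {
        out[4*threadIdx.x + l] = d[l];
    }
}

int main() {
    int * d_out;
    cudaMalloc(&d_out, 32*4*sizeof(int));
    mma_s8_demo<<<1, 32>>>(d_out);   // exactly one warp

    int h_out[32*4];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    int bad = 0;
    for (int i = 0; i < 32*4; ++i) {
        bad += h_out[i] != 16;
    }
    printf("%s\n", bad == 0 ? "all 128 accumulators == 16" : "unexpected result");
    cudaFree(d_out);
    return 0;
}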
@@ -1304,7 +1476,7 @@ template <bool need_check> static __global__ void
 #if __CUDA_ARCH__ >= MIN_CC_DP4A
     constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q8_0);
 
-    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q8_0<arch_config.y>,
+    mul_mat_q_test<QK8_0, QR8_0, QI8_0, false, block_q8_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q8_0<arch_config.y>,
         load_tiles_q8_0<arch_config.y, arch_config.nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
