Skip to content

CUDA: more warps for mmvq on NVIDIA #5394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 86 additions & 47 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5310,45 +5310,65 @@ template <bool need_check> static __global__ void
#endif // __CUDA_ARCH__ >= CC_VOLTA
}

template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
#define MMVQ_NWARPS_NVIDIA 4
#define MMVQ_NWARPS_AMD_RDNA2 1
#define MMVQ_NWARPS_AMD_OLD 4

template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {

const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;

const int row = blockIdx.x*blockDim.y + threadIdx.y;

if (row >= nrows_x) {
return;
}
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
const int row = blockIdx.x;

const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_col_y = nrows_y / QK8_1;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;

// partial sum for each thread
float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};

const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;

for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) {
for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) {
const int ibx = row*blocks_per_row_x + i; // x block index

const int iby = i * (qk/QK8_1); // y block index that aligns with ibx

const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int

#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
tmp[j] += vec_dot_q_cuda(&x[ibx], &y[j*blocks_per_col_y + iby], iqs);
}
}

__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE];
if (threadIdx.y > 0) {
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j];
}
}
__syncthreads();
if (threadIdx.y > 0) {
return;
}

// sum up partial sums and write back result
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
for (int i = 0; i < nwarps-1; ++i) {
tmp[j] += tmp_shared[i][j][threadIdx.x];
}
tmp[j] = warp_reduce_sum(tmp[j]);

if (threadIdx.x == 0) {
Expand Down Expand Up @@ -6833,46 +6853,65 @@ static void mul_mat_vec_q_cuda(
GGML_ASSERT(ncols_x % qk == 0);
GGML_ASSERT(ncols_y <= 4);

const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
switch (ncols_y) {
case 1:
mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
case 2:
mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
case 3:
mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
case 4:
mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
// case 5:
// mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
// break;
// case 6:
// mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
// break;
// case 7:
// mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
// break;
// case 8:
// mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
// break;
int id;
CUDA_CHECK(cudaGetDevice(&id));

int nwarps;
if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
} else {
nwarps = MMVQ_NWARPS_NVIDIA;
}

const dim3 block_nums(nrows_x, 1, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);

switch (nwarps) {
case 1: switch(ncols_y) {
case 1:
mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
case 2:
mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
case 3:
mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
case 4:
mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
default:
GGML_ASSERT(false);
break;
} break;
case 4: switch(ncols_y) {
case 1:
mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
case 2:
mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
case 3:
mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
case 4:
mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
default:
GGML_ASSERT(false);
break;
} break;

default:
GGML_ASSERT(false);
// mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
break;
}
}
Expand Down