Skip to content

CUDA: fix MMQ stream-k rounding if ne00 % 128 != 0 #8311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions ggml/src/ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2305,8 +2305,11 @@ static __global__ void mul_mat_q(
const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y

// kbc == k block continuous, current index in continuous ijk space.
int64_t kbc = GGML_PAD((int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x;
int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;

kbc -= (kbc % blocks_per_ne00) % blocks_per_warp;
kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp;

// kb0 == k index when doing the matrix multiplication for an output tile.
int kb0_start = kbc % blocks_per_ne00;
Expand Down Expand Up @@ -2362,8 +2365,11 @@ static __global__ void mul_mat_q_stream_k_fixup(
const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;

for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
const int64_t kbc = GGML_PAD((int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
int64_t kbc = (int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq;
int64_t kbc_stop = (int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;

kbc -= (kbc % blocks_per_ne00) % blocks_per_warp;
kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp;

// Skip fixup tile if the MMQ CUDA block never wrote anything to it:
if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
Expand Down
Loading