Skip to content

Commit bcefa03

Browse files
CUDA: fix MMQ stream-k rounding if ne00 % 128 != 0 (#8311)
1 parent 5a7447c commit bcefa03

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2305,8 +2305,11 @@ static __global__ void mul_mat_q(
23052305
const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
23062306

23072307
// kbc == k block continuous, current index in continuous ijk space.
2308-
int64_t kbc = GGML_PAD((int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
2309-
const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
2308+
int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x;
2309+
int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
2310+
2311+
kbc -= (kbc % blocks_per_ne00) % blocks_per_warp;
2312+
kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp;
23102313

23112314
// kb0 == k index when doing the matrix multiplication for an output tile.
23122315
int kb0_start = kbc % blocks_per_ne00;
@@ -2362,8 +2365,11 @@ static __global__ void mul_mat_q_stream_k_fixup(
23622365
const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
23632366

23642367
for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
2365-
const int64_t kbc = GGML_PAD((int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
2366-
const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
2368+
int64_t kbc = (int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq;
2369+
int64_t kbc_stop = (int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
2370+
2371+
kbc -= (kbc % blocks_per_ne00) % blocks_per_warp;
2372+
kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_warp;
23672373

23682374
// Skip fixup tile if the MMQ CUDA block never wrote anything to it:
23692375
if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {

0 commit comments

Comments
 (0)