@@ -5327,6 +5327,10 @@ static __global__ void mul_mat_vec_q(
5327
5327
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2 ;
5328
5328
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
5329
5329
5330
+ const int tid = WARP_SIZE*threadIdx .y + threadIdx .x ;
5331
+ const int row0 = rows_per_cuda_block*blockIdx .x ;
5332
+ const int blocks_per_row_x = ncols_x / qk;
5333
+ const int blocks_per_col_y = nrows_y / QK8_1;
5330
5334
constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
5331
5335
5332
5336
// partial sum for each thread
@@ -5335,18 +5339,18 @@ static __global__ void mul_mat_vec_q(
5335
5339
const block_q_t * x = (const block_q_t *) vx;
5336
5340
const block_q8_1 * y = (const block_q8_1 *) vy;
5337
5341
5338
- for (int kbx = (WARP_SIZE* threadIdx . y + threadIdx . x ) / (qi/vdr); kbx < (ncols_x / qk) ; kbx += blocks_per_iter) {
5342
+ for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x ; kbx += blocks_per_iter) {
5339
5343
const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
5340
5344
5341
5345
// x block quant index when casting the quants to int
5342
- const int kqs = vdr * ((WARP_SIZE* threadIdx . y + threadIdx . x ) % (qi/vdr));
5346
+ const int kqs = vdr * (tid % (qi/vdr));
5343
5347
5344
5348
#pragma unroll
5345
5349
for (int j = 0 ; j < ncols_y; ++j) {
5346
5350
#pragma unroll
5347
5351
for (int i = 0 ; i < rows_per_cuda_block; ++i) {
5348
5352
tmp[j][i] += vec_dot_q_cuda (
5349
- &x[kbx + (rows_per_cuda_block* blockIdx . x + i)*(ncols_x / qk) ], &y[j*(nrows_y / QK8_1) + kby], kqs);
5353
+ &x[kbx + (row0 + i)*blocks_per_row_x ], &y[j*blocks_per_col_y + kby], kqs);
5350
5354
}
5351
5355
}
5352
5356
}
@@ -5379,7 +5383,7 @@ static __global__ void mul_mat_vec_q(
5379
5383
}
5380
5384
5381
5385
if (threadIdx .x < rows_per_cuda_block) {
5382
- dst[j*nrows_dst + rows_per_cuda_block* blockIdx . x + threadIdx .x ] = tmp[j][threadIdx .x ];
5386
+ dst[j*nrows_dst + row0 + threadIdx .x ] = tmp[j][threadIdx .x ];
5383
5387
}
5384
5388
}
5385
5389
}
0 commit comments