Skip to content

Commit 76a0128

Browse files
revert low register pressure changes
1 parent 2bb97fc commit 76a0128

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

ggml-cuda.cu

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5327,6 +5327,10 @@ static __global__ void mul_mat_vec_q(
53275327
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
53285328
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
53295329

5330+
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
5331+
const int row0 = rows_per_cuda_block*blockIdx.x;
5332+
const int blocks_per_row_x = ncols_x / qk;
5333+
const int blocks_per_col_y = nrows_y / QK8_1;
53305334
constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
53315335

53325336
// partial sum for each thread
@@ -5335,18 +5339,18 @@ static __global__ void mul_mat_vec_q(
53355339
const block_q_t * x = (const block_q_t *) vx;
53365340
const block_q8_1 * y = (const block_q8_1 *) vy;
53375341

5338-
for (int kbx = (WARP_SIZE*threadIdx.y + threadIdx.x) / (qi/vdr); kbx < (ncols_x / qk); kbx += blocks_per_iter) {
5342+
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
53395343
const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
53405344

53415345
// x block quant index when casting the quants to int
5342-
const int kqs = vdr * ((WARP_SIZE*threadIdx.y + threadIdx.x) % (qi/vdr));
5346+
const int kqs = vdr * (tid % (qi/vdr));
53435347

53445348
#pragma unroll
53455349
for (int j = 0; j < ncols_y; ++j) {
53465350
#pragma unroll
53475351
for (int i = 0; i < rows_per_cuda_block; ++i) {
53485352
tmp[j][i] += vec_dot_q_cuda(
5349-
&x[kbx + (rows_per_cuda_block*blockIdx.x + i)*(ncols_x / qk)], &y[j*(nrows_y / QK8_1) + kby], kqs);
5353+
&x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
53505354
}
53515355
}
53525356
}
@@ -5379,7 +5383,7 @@ static __global__ void mul_mat_vec_q(
53795383
}
53805384

53815385
if (threadIdx.x < rows_per_cuda_block) {
5382-
dst[j*nrows_dst + rows_per_cuda_block*blockIdx.x + threadIdx.x] = tmp[j][threadIdx.x];
5386+
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
53835387
}
53845388
}
53855389
}

0 commit comments

Comments
 (0)