Skip to content

Commit b5ce497

Browse files
committed
increase the number of rows per block
1 parent 44b3b3c commit b5ce497

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

ggml-cuda.cu

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -228,7 +228,7 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
228228
template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
229229
const block_q4_0 * x = (const block_q4_0 *) vx;
230230

231-
const int row = blockIdx.x;
231+
const int row = blockIdx.x * 2 + threadIdx.y;
232232
const int tid = threadIdx.x;
233233

234234
float tmp = 0;
@@ -305,9 +305,12 @@ static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float
305305
// }
306306
// }
307307
// dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
308-
const int block_size = 32;
309-
GGML_ASSERT(ncols % block_size == 0);
310-
dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
308+
const int reduce_size = 32;
309+
const int rows_per_block = 2;
310+
const dim3 block_size(reduce_size, rows_per_block, 1);
311+
GGML_ASSERT(nrows % rows_per_block == 0);
312+
GGML_ASSERT(ncols % reduce_size == 0);
313+
dequantize_mul_mat_q4_0<reduce_size><<<nrows / rows_per_block, block_size, 0, stream>>>(vx, y, dst, ncols);
311314
}
312315

313316
// TODO: optimize

0 commit comments

Comments
 (0)