@@ -228,7 +228,7 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
228
228
template <int block_size> static __global__ void dequantize_mul_mat_q4_0 (const void * vx, const float * y, float * dst, const int ncols) {
229
229
const block_q4_0 * x = (const block_q4_0 *) vx;
230
230
231
- const int row = blockIdx .x ;
231
+ const int row = blockIdx .x * 2 + threadIdx . y ;
232
232
const int tid = threadIdx .x ;
233
233
234
234
float tmp = 0 ;
@@ -305,9 +305,12 @@ static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float
305
305
// }
306
306
// }
307
307
// dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
308
- const int block_size = 32 ;
309
- GGML_ASSERT (ncols % block_size == 0 );
310
- dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0 , stream>>> (vx, y, dst, ncols);
308
+ const int reduce_size = 32 ;
309
+ const int rows_per_block = 2 ;
310
+ const dim3 block_size (reduce_size, rows_per_block, 1 );
311
+ GGML_ASSERT (nrows % rows_per_block == 0 );
312
+ GGML_ASSERT (ncols % reduce_size == 0 );
313
+ dequantize_mul_mat_q4_0<reduce_size><<<nrows / rows_per_block, block_size, 0 , stream>>> (vx, y, dst, ncols);
311
314
}
312
315
313
316
// TODO: optimize
0 commit comments