@@ -235,8 +235,8 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
235
235
__shared__ float tmp[block_size]; // separate sum for each thread
236
236
tmp[tid] = 0 ;
237
237
238
- for (int i = 0 ; i < ncols/block_size; i += 2 ) {
239
- const int col = i*block_size + 2 *tid;
238
+ for (int i = 0 ; i < ncols/block_size; i += 4 ) {
239
+ const int col = i*block_size + 4 *tid;
240
240
241
241
// dequantize
242
242
const float d0 = x[(row*ncols + col)/QK4_0].d ;
@@ -245,19 +245,21 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
245
245
const uint8_t * p0 = x[(row*ncols + col)/QK4_0].qs ;
246
246
const int8_t * p1 = y[col/QK8_0].qs ;
247
247
248
- const uint8_t vui0 = p0[((row*ncols + col)%QK4_0)/2 ];
248
+ const uint8_t vui00 = p0[((row*ncols + col)%QK4_0)/2 ];
249
+ const uint8_t vui01 = p0[((row*ncols + col + 2 )%QK4_0)/2 ];
249
250
const int vi10 = p1[(col + 0 )%QK8_0];
250
251
const int vi11 = p1[(col + 1 )%QK8_0];
252
+ const int vi12 = p1[(col + 2 )%QK8_0];
253
+ const int vi13 = p1[(col + 3 )%QK8_0];
251
254
252
- const int vi00 = vui0 & 0xF ;
253
- const int vi01 = vui0 >> 4 ;
254
-
255
- const float v0 = (vi00 - 8 )*vi10*d0*d1;
256
- const float v1 = (vi01 - 8 )*vi11*d0*d1;
255
+ const int vi00 = vui00 & 0xF ;
256
+ const int vi01 = vui00 >> 4 ;
257
+ const int vi02 = vui01 & 0xF ;
258
+ const int vi03 = vui01 >> 4 ;
257
259
258
260
// matrix multiplication
259
- tmp[tid] += v0 ;
260
- tmp[tid] += v1 ;
261
+ const int sumi = (vi00 - 8 )*vi10 + (vi01 - 8 )*vi11 + (vi02 - 8 )*vi12 + (vi03 - 8 )*vi13 ;
262
+ tmp[tid] += sumi*d0*d1 ;
261
263
}
262
264
263
265
// sum up partial sums and write back result
0 commit comments