@@ -225,16 +225,18 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     }
 }
 
-template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
+template <int reduce_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
     const int row = blockIdx.x*2 + threadIdx.y;
     const int tid = threadIdx.x;
 
-    float tmp = 0;
+    __shared__ float full_tmp[reduce_size*2]; // separate sum for each thread
+    float * tmp = full_tmp + reduce_size*threadIdx.y;
+    tmp[tid] = 0;
 
-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
+    for (int i = 0; i < ncols/reduce_size; i += 2) {
+        const int col = i*reduce_size + 2*tid;
 
         // dequantize
         const float d = x[(row*ncols + col)/QK4_0].d;
@@ -250,16 +252,18 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
         const float v1 = (vi1 - 8)*d;
 
         // matrix multiplication
-        tmp += v0 * y[col + 0];
-        tmp += v1 * y[col + 1];
+        tmp[tid] += v0 * y[col + 0];
+        tmp[tid] += v1 * y[col + 1];
     }
 
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    for (int s = reduce_size/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        __syncthreads();
     }
 
     if (tid == 0) {
-        dst[row] = tmp;
+        dst[row] = tmp[0];
     }
 }
 
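In effect, this change swaps the warp-wide __shfl_xor_sync reduction for a shared-memory tree reduction whose width is the reduce_size template parameter. Below is a minimal standalone sketch of that reduction pattern; the sum_rows kernel, its host code, and the chosen reduce_size of 32 are illustrative assumptions rather than part of this commit, and the pattern assumes reduce_size is a power of two with blockDim.x == reduce_size.

// Illustrative sketch only: shared-memory tree reduction, one block per row.
// Names and launch parameters are hypothetical, not taken from the commit.
#include <cstdio>
#include <cuda_runtime.h>

template <int reduce_size>
static __global__ void sum_rows(const float * x, float * dst, const int ncols) {
    const int row = blockIdx.x;        // one block per output row
    const int tid = threadIdx.x;       // blockDim.x == reduce_size

    __shared__ float tmp[reduce_size]; // one partial sum per thread
    tmp[tid] = 0.0f;

    // strided accumulation over the row
    for (int col = tid; col < ncols; col += reduce_size) {
        tmp[tid] += x[row*ncols + col];
    }
    __syncthreads();

    // tree reduction: halve the number of active threads each step
    for (int s = reduce_size/2; s > 0; s >>= 1) {
        if (tid < s) {
            tmp[tid] += tmp[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) {
        dst[row] = tmp[0];
    }
}

int main() {
    const int nrows = 4, ncols = 256;
    float *x, *dst;
    cudaMallocManaged(&x, nrows*ncols*sizeof(float));
    cudaMallocManaged(&dst, nrows*sizeof(float));
    for (int i = 0; i < nrows*ncols; ++i) x[i] = 1.0f;

    sum_rows<32><<<nrows, 32>>>(x, dst, ncols); // reduce_size = 32 threads per row
    cudaDeviceSynchronize();

    for (int i = 0; i < nrows; ++i) printf("row %d sum = %.0f\n", i, dst[i]); // expect 256
    cudaFree(x);
    cudaFree(dst);
    return 0;
}

In the commit itself the shared buffer is sized reduce_size*2 because each block processes two rows (row = blockIdx.x*2 + threadIdx.y), so the single __syncthreads() inside the loop synchronizes the threads of both rows at every reduction step.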