File tree Expand file tree Collapse file tree 1 file changed +7
-11
lines changed Expand file tree Collapse file tree 1 file changed +7
-11
lines changed Original file line number Diff line number Diff line change @@ -231,8 +231,7 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
231
231
const int row = blockIdx .x ;
232
232
const int tid = threadIdx .x ;
233
233
234
- __shared__ float tmp[block_size]; // separate sum for each thread
235
- tmp[tid] = 0 ;
234
+ float tmp = 0 ;
236
235
237
236
for (int i = 0 ; i < ncols/block_size; i += 2 ) {
238
237
const int col = i*block_size + 2 *tid;
@@ -251,19 +250,16 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
251
250
const float v1 = (vi1 - 8 )*d;
252
251
253
252
// matrix multiplication
254
- tmp[tid] += v0 * y[col + 0 ];
255
- tmp[tid] += v1 * y[col + 1 ];
253
+ tmp += v0 * y[col + 0 ];
254
+ tmp += v1 * y[col + 1 ];
256
255
}
257
256
258
- // sum up partial sums and write back result
259
- for (int s=block_size/2 ; s>0 ; s>>=1 ) {
260
- if (tid < s) {
261
- tmp[tid] += tmp[tid + s];
262
- }
263
- __syncthreads ();
257
+ #pragma unroll
258
+ for (int mask = 16 ; mask > 0 ; mask >>= 1 ) {
259
+ tmp += __shfl_xor_sync (0xffffffff , tmp, mask, 32 );
264
260
}
265
261
if (tid == 0 ) {
266
- dst[row] = tmp[ 0 ] ;
262
+ dst[row] = tmp;
267
263
}
268
264
}
269
265
You can’t perform that action at this time.
0 commit comments