Skip to content

Commit 44b3b3c

Browse files
committed
Get rid of shared memory (replace the shared-memory partial-sum reduction with a warp-shuffle reduction)
1 parent 3ed4588 commit 44b3b3c

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

ggml-cuda.cu

Lines changed: 7 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -231,8 +231,7 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
231231
const int row = blockIdx.x;
232232
const int tid = threadIdx.x;
233233

234-
__shared__ float tmp[block_size]; // separate sum for each thread
235-
tmp[tid] = 0;
234+
float tmp = 0;
236235

237236
for (int i = 0; i < ncols/block_size; i += 2) {
238237
const int col = i*block_size + 2*tid;
@@ -251,19 +250,16 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
251250
const float v1 = (vi1 - 8)*d;
252251

253252
// matrix multiplication
254-
tmp[tid] += v0 * y[col + 0];
255-
tmp[tid] += v1 * y[col + 1];
253+
tmp += v0 * y[col + 0];
254+
tmp += v1 * y[col + 1];
256255
}
257256

258-
// sum up partial sums and write back result
259-
for (int s=block_size/2; s>0; s>>=1) {
260-
if (tid < s) {
261-
tmp[tid] += tmp[tid + s];
262-
}
263-
__syncthreads();
257+
#pragma unroll
258+
for (int mask = 16; mask > 0; mask >>= 1) {
259+
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
264260
}
265261
if (tid == 0) {
266-
dst[row] = tmp[0];
262+
dst[row] = tmp;
267263
}
268264
}
269265

0 commit comments

Comments
 (0)