Commit f63760f
Parent: b5ce497

shared memory illustration

1 file changed: 14 additions, 10 deletions

ggml-cuda.cu

Lines changed: 14 additions & 10 deletions
@@ -225,16 +225,18 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     }
 }
 
-template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
+template <int reduce_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
     const int row = blockIdx.x * 2 + threadIdx.y;
     const int tid = threadIdx.x;
 
-    float tmp = 0;
+    __shared__ float full_tmp[reduce_size * 2]; // separate sum for each thread
+    float* tmp = full_tmp + reduce_size * threadIdx.y;
+    tmp[tid] = 0;
 
-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
+    for (int i = 0; i < ncols/reduce_size; i += 2) {
+        const int col = i*reduce_size + 2*tid;
 
         // dequantize
         const float d = x[(row*ncols + col)/QK4_0].d;
@@ -250,16 +252,18 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
         const float v1 = (vi1 - 8)*d;
 
         // matrix multiplication
-        tmp += v0 * y[col + 0];
-        tmp += v1 * y[col + 1];
+        tmp[tid] += v0 * y[col + 0];
+        tmp[tid] += v1 * y[col + 1];
     }
 
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    for (int s=reduce_size/2; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        __syncthreads();
     }
     if (tid == 0) {
-        dst[row] = tmp;
+        dst[row] = tmp[0];
     }
 }

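For reference, the pattern this patch switches to is a classic shared-memory tree reduction: every thread writes its partial sum into a __shared__ array, then on each step the lower half of the active threads folds in the upper half until thread 0 holds the total. Below is a minimal standalone sketch of that pattern; the kernel name mat_vec_rowsum, the REDUCE_SIZE constant, and the plain-float inputs are illustrative assumptions only and are not part of ggml-cuda.cu (the real kernel dequantizes q4_0 blocks and reduces two rows per thread block).

// Standalone sketch of the shared-memory tree reduction used in the patch.
// Names here (mat_vec_rowsum, REDUCE_SIZE) are illustrative, not from ggml.
#include <cstdio>
#include <cuda_runtime.h>

#define REDUCE_SIZE 32  // threads per block; must be a power of two

// Each block computes the dot product of one matrix row with the vector y.
__global__ void mat_vec_rowsum(const float * x, const float * y, float * dst, int ncols) {
    __shared__ float tmp[REDUCE_SIZE];   // one partial sum per thread
    const int row = blockIdx.x;
    const int tid = threadIdx.x;

    // each thread accumulates a strided slice of the row
    float sum = 0.0f;
    for (int col = tid; col < ncols; col += REDUCE_SIZE) {
        sum += x[row * ncols + col] * y[col];
    }
    tmp[tid] = sum;
    __syncthreads();

    // tree reduction in shared memory: halve the active threads each step
    for (int s = REDUCE_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            tmp[tid] += tmp[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) {
        dst[row] = tmp[0];   // thread 0 now holds the full row sum
    }
}

int main() {
    const int nrows = 4, ncols = 256;
    float *x, *y, *dst;
    cudaMallocManaged(&x, nrows * ncols * sizeof(float));
    cudaMallocManaged(&y, ncols * sizeof(float));
    cudaMallocManaged(&dst, nrows * sizeof(float));
    for (int i = 0; i < nrows * ncols; ++i) x[i] = 1.0f;
    for (int i = 0; i < ncols; ++i) y[i] = 1.0f;

    mat_vec_rowsum<<<nrows, REDUCE_SIZE>>>(x, y, dst, ncols);
    cudaDeviceSynchronize();

    for (int i = 0; i < nrows; ++i) printf("row %d: %f\n", i, dst[i]);  // expect 256 per row
    cudaFree(x); cudaFree(y); cudaFree(dst);
    return 0;
}

Note that __syncthreads() sits outside the if (tid < s) branch so every thread in the block reaches the barrier. The warp-shuffle version the patch replaces needs no shared memory or barriers, but __shfl_xor_sync with width 32 only reduces within a single warp.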