Skip to content

Commit 800c963

Browse files
authored
Fix CUDA softmax by subtracting max value before exp (#2665)
1 parent deb7dfc commit 800c963

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

ggml-cuda.cu

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3979,24 +3979,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
39793979

39803980
// the CUDA soft max implementation differs from the CPU implementation
39813981
// floats are used instead of doubles
3982-
// values are also not normalized to the maximum value by subtracting it in the exponential function
3983-
// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
39843982
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
39853983
const int row = blockDim.x*blockIdx.x + threadIdx.x;
39863984
const int block_size = blockDim.y;
39873985
const int tid = threadIdx.y;
39883986

3989-
float tmp = 0.0;
3987+
float max_val = -INFINITY;
39903988

3991-
for (int block_start = 0; block_start < ncols; block_start += block_size) {
3992-
const int col = block_start + tid;
3989+
for (int col = tid; col < ncols; col += block_size) {
3990+
const int i = row*ncols + col;
3991+
max_val = max(max_val, x[i]);
3992+
}
39933993

3994-
if (col >= ncols) {
3995-
break;
3996-
}
3994+
// find the max value across the 32 lanes of the warp
3995+
#pragma unroll
3996+
for (int mask = 16; mask > 0; mask >>= 1) {
3997+
max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
3998+
}
3999+
4000+
float tmp = 0.f;
39974001

4002+
for (int col = tid; col < ncols; col += block_size) {
39984003
const int i = row*ncols + col;
3999-
const float val = expf(x[i]);
4004+
const float val = expf(x[i] - max_val);
40004005
tmp += val;
40014006
dst[i] = val;
40024007
}
@@ -4007,15 +4012,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
40074012
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
40084013
}
40094014

4010-
for (int block_start = 0; block_start < ncols; block_start += block_size) {
4011-
const int col = block_start + tid;
4012-
4013-
if (col >= ncols) {
4014-
break;
4015-
}
4015+
const float inv_tmp = 1.f / tmp;
40164016

4017+
for (int col = tid; col < ncols; col += block_size) {
40174018
const int i = row*ncols + col;
4018-
dst[i] /= tmp;
4019+
dst[i] *= inv_tmp;
40194020
}
40204021
}
40214022

0 commit comments

Comments
 (0)