Skip to content

Commit 12e4284

Browse files
committed
Fix CUDA softmax by subtracting max value before exp
1 parent 1f0bccb commit 12e4284

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

ggml-cuda.cu

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3955,24 +3955,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
39553955

39563956
// the CUDA soft max implementation differs from the CPU implementation
39573957
// instead of doubles floats are used
3958-
// values are also not normalized to the maximum value by subtracting it in the exponential function
3959-
// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
39603958
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
39613959
const int row = blockDim.y*blockIdx.y + threadIdx.y;
39623960
const int block_size = blockDim.x;
39633961
const int tid = threadIdx.x;
39643962

3965-
float tmp = 0.0;
3963+
float max_val = -INFINITY;
39663964

3967-
for (int block_start = 0; block_start < ncols; block_start += block_size) {
3968-
const int col = block_start + tid;
3965+
for (int col = tid; col < ncols; col += block_size) {
3966+
const int i = row*ncols + col;
3967+
max_val = max(max_val, x[i]);
3968+
}
39693969

3970-
if (col >= ncols) {
3971-
break;
3972-
}
3970+
// find the max value in the block
3971+
#pragma unroll
3972+
for (int mask = 16; mask > 0; mask >>= 1) {
3973+
max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
3974+
}
3975+
3976+
float tmp = 0.f;
39733977

3978+
for (int col = tid; col < ncols; col += block_size) {
39743979
const int i = row*ncols + col;
3975-
const float val = expf(x[i]);
3980+
const float val = expf(x[i] - max_val);
39763981
tmp += val;
39773982
dst[i] = val;
39783983
}
@@ -3983,15 +3988,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
39833988
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
39843989
}
39853990

3986-
for (int block_start = 0; block_start < ncols; block_start += block_size) {
3987-
const int col = block_start + tid;
3988-
3989-
if (col >= ncols) {
3990-
break;
3991-
}
3991+
const float inv_tmp = 1.f / tmp;
39923992

3993+
for (int col = tid; col < ncols; col += block_size) {
39933994
const int i = row*ncols + col;
3994-
dst[i] /= tmp;
3995+
dst[i] *= inv_tmp;
39953996
}
39963997
}
39973998

0 commit comments

Comments
 (0)