Fix CUDA softmax by subtracting max value before exp #2665


Merged
merged 1 commit on Aug 22, 2023
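
For context, softmax is shift-invariant: softmax(x) = softmax(x - c) for any constant c, because the common factor exp(-c) cancels between numerator and denominator. Choosing c = max(x) makes every argument of expf() non-positive, so the exponential can no longer overflow to infinity. A minimal host-side C++ sketch of the idea (a standalone illustration, not code from this PR):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// numerically stable softmax: subtracting the row maximum before expf
// keeps every argument <= 0, so expf can never return inf
static std::vector<float> softmax_stable(const std::vector<float> & x) {
    const float max_val = *std::max_element(x.begin(), x.end());
    std::vector<float> y(x.size());
    float sum = 0.0f;
    for (size_t i = 0; i < x.size(); ++i) {
        y[i] = expf(x[i] - max_val); // naive expf(x[i]) overflows for x[i] > ~88
        sum += y[i];
    }
    for (float & v : y) {
        v /= sum;
    }
    return y;
}

int main() {
    // expf(102.0f) is inf in single precision, so the unshifted kernel would
    // compute inf/inf = nan here; the shifted version stays finite
    for (float p : softmax_stable({100.0f, 101.0f, 102.0f})) {
        printf("%f\n", p);
    }
    return 0;
}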
33 changes: 17 additions & 16 deletions ggml-cuda.cu
@@ -3955,24 +3955,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 
 // the CUDA soft max implementation differs from the CPU implementation
 // instead of doubles floats are used
-// values are also not normalized to the maximum value by subtracting it in the exponential function
-// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
     const int block_size = blockDim.x;
     const int tid = threadIdx.x;
 
-    float tmp = 0.0;
+    float max_val = -INFINITY;
 
-    for (int block_start = 0; block_start < ncols; block_start += block_size) {
-        const int col = block_start + tid;
+    for (int col = tid; col < ncols; col += block_size) {
+        const int i = row*ncols + col;
+        max_val = max(max_val, x[i]);
+    }
 
-        if (col >= ncols) {
-            break;
-        }
+    // find the max value in the block
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    }
+
+    float tmp = 0.f;
 
+    for (int col = tid; col < ncols; col += block_size) {
         const int i = row*ncols + col;
-        const float val = expf(x[i]);
+        const float val = expf(x[i] - max_val);
         tmp += val;
         dst[i] = val;
     }
@@ -3983,15 +3988,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    for (int block_start = 0; block_start < ncols; block_start += block_size) {
-        const int col = block_start + tid;
-
-        if (col >= ncols) {
-            break;
-        }
+    const float inv_tmp = 1.f / tmp;
 
+    for (int col = tid; col < ncols; col += block_size) {
         const int i = row*ncols + col;
-        dst[i] /= tmp;
+        dst[i] *= inv_tmp;
     }
 }

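The per-row maximum is combined across the 32 lanes of a warp with the same __shfl_xor_sync butterfly pattern already used for the sum. A standalone CUDA sketch of that reduction (the helper name warp_reduce_max is illustrative, not from this PR):

// after the 5 XOR-shuffle steps with lane masks 16, 8, 4, 2, 1, every lane
// of the warp holds the maximum over all 32 lanes
__device__ float warp_reduce_max(float val) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        // exchange with the lane whose id differs in one bit, keep the larger value
        val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32));
    }
    return val;
}

The second hunk also hoists the division out of the normalization loop: 1.f / tmp is computed once per thread and each output is multiplied by the reciprocal, trading a floating-point division per element for a single division plus cheap multiplies.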