@@ -3955,24 +3955,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
3955
3955
3956
3956
// the CUDA soft max implementation differs from the CPU implementation
3957
3957
// instead of doubles floats are used
3958
- // values are also not normalized to the maximum value by subtracting it in the exponential function
3959
- // theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
3960
3958
static __global__ void soft_max_f32 (const float * x, float * dst, const int ncols) {
3961
3959
const int row = blockDim .y *blockIdx .y + threadIdx .y ;
3962
3960
const int block_size = blockDim .x ;
3963
3961
const int tid = threadIdx .x ;
3964
3962
3965
- float tmp = 0.0 ;
3963
+ float max_val = -INFINITY ;
3966
3964
3967
- for (int block_start = 0 ; block_start < ncols; block_start += block_size) {
3968
- const int col = block_start + tid;
3965
+ for (int col = tid; col < ncols; col += block_size) {
3966
+ const int i = row*ncols + col;
3967
+ max_val = max (max_val, x[i]);
3968
+ }
3969
3969
3970
- if (col >= ncols) {
3971
- break ;
3972
- }
3970
+ // find the max value in the block
3971
+ #pragma unroll
3972
+ for (int mask = 16 ; mask > 0 ; mask >>= 1 ) {
3973
+ max_val = max (max_val, __shfl_xor_sync (0xffffffff , max_val, mask, 32 ));
3974
+ }
3975
+
3976
+ float tmp = 0 .f ;
3973
3977
3978
+ for (int col = tid; col < ncols; col += block_size) {
3974
3979
const int i = row*ncols + col;
3975
- const float val = expf (x[i]);
3980
+ const float val = expf (x[i] - max_val );
3976
3981
tmp += val;
3977
3982
dst[i] = val;
3978
3983
}
@@ -3983,15 +3988,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
3983
3988
tmp += __shfl_xor_sync (0xffffffff , tmp, mask, 32 );
3984
3989
}
3985
3990
3986
- for (int block_start = 0 ; block_start < ncols; block_start += block_size) {
3987
- const int col = block_start + tid;
3988
-
3989
- if (col >= ncols) {
3990
- break ;
3991
- }
3991
+ const float inv_tmp = 1 .f / tmp;
3992
3992
3993
+ for (int col = tid; col < ncols; col += block_size) {
3993
3994
const int i = row*ncols + col;
3994
- dst[i] /= tmp ;
3995
+ dst[i] *= inv_tmp ;
3995
3996
}
3996
3997
}
3997
3998
0 commit comments