@@ -3979,24 +3979,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
3979
3979
3980
3980
// the CUDA soft max implementation differs from the CPU implementation
3981
3981
// instead of doubles floats are used
3982
- // values are also not normalized to the maximum value by subtracting it in the exponential function
3983
- // theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
3984
3982
static __global__ void soft_max_f32 (const float * x, float * dst, const int ncols) {
3985
3983
const int row = blockDim .x *blockIdx .x + threadIdx .x ;
3986
3984
const int block_size = blockDim .y ;
3987
3985
const int tid = threadIdx .y ;
3988
3986
3989
- float tmp = 0.0 ;
3987
+ float max_val = -INFINITY ;
3990
3988
3991
- for (int block_start = 0 ; block_start < ncols; block_start += block_size) {
3992
- const int col = block_start + tid;
3989
+ for (int col = tid; col < ncols; col += block_size) {
3990
+ const int i = row*ncols + col;
3991
+ max_val = max (max_val, x[i]);
3992
+ }
3993
3993
3994
- if (col >= ncols) {
3995
- break ;
3996
- }
3994
+ // find the max value in the block
3995
+ #pragma unroll
3996
+ for (int mask = 16 ; mask > 0 ; mask >>= 1 ) {
3997
+ max_val = max (max_val, __shfl_xor_sync (0xffffffff , max_val, mask, 32 ));
3998
+ }
3999
+
4000
+ float tmp = 0 .f ;
3997
4001
4002
+ for (int col = tid; col < ncols; col += block_size) {
3998
4003
const int i = row*ncols + col;
3999
- const float val = expf (x[i]);
4004
+ const float val = expf (x[i] - max_val );
4000
4005
tmp += val;
4001
4006
dst[i] = val;
4002
4007
}
@@ -4007,15 +4012,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
4007
4012
tmp += __shfl_xor_sync (0xffffffff , tmp, mask, 32 );
4008
4013
}
4009
4014
4010
- for (int block_start = 0 ; block_start < ncols; block_start += block_size) {
4011
- const int col = block_start + tid;
4012
-
4013
- if (col >= ncols) {
4014
- break ;
4015
- }
4015
+ const float inv_tmp = 1 .f / tmp;
4016
4016
4017
+ for (int col = tid; col < ncols; col += block_size) {
4017
4018
const int i = row*ncols + col;
4018
- dst[i] /= tmp ;
4019
+ dst[i] *= inv_tmp ;
4019
4020
}
4020
4021
}
4021
4022
0 commit comments