Skip to content

Commit 4947778

Browse files
committed
Fix some more int overflow in softmax.
1 parent 9acb43d commit 4947778

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ggml-cuda/softmax.cu

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
28 28
extern __shared__ float data_soft_max_f32[];
29 29
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
30 30
// shared memory buffer to cache values between iterations:
31-
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
31+
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
32 32

33 33
float max_val = -INFINITY;
34 34

@@ -40,8 +40,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
40 40
break;
41 41
}
42 42

43-
const int ix = rowx*ncols + col;
44-
const int iy = rowy*ncols + col;
43+
const int64_t ix = (int64_t)rowx*ncols + col;
44+
const int64_t iy = (int64_t)rowy*ncols + col;
45 45

46 46
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
47 47

@@ -109,7 +109,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
109 109
return;
110 110
}
111 111

112-
const int idst = rowx*ncols + col;
112+
const int64_t idst = (int64_t)rowx*ncols + col;
113 113
dst[idst] = vals[col] * inv_sum;
114 114
}
115 115
}

0 commit comments

Comments
 (0)