Skip to content

Commit b68a112

Browse files
committed
cuda : fix __hisinf() result check
1 parent 12eaa22 commit b68a112

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

ggml-cuda.cu

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6597,8 +6597,8 @@ static __global__ void flash_attn_ext_f16(
65976597
smax = warp_reduce_max(__hmax(smax, s));
65986598
M[j] = warp_reduce_max(__hmax(M[j], s));
65996599

6600-
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
6601-
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
6600+
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
6601+
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
66026602

66036603
S[j] = S[j]*ms + warp_reduce_sum(vs);
66046604

@@ -6624,7 +6624,7 @@ static __global__ void flash_attn_ext_f16(
66246624
smax = warp_reduce_max(smax);
66256625
M[j] = warp_reduce_max(M[j]);
66266626

6627-
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
6627+
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
66286628

66296629
// create a QxQ diagonal matrix for rescaling the output
66306630
if (lane_id == j) {
@@ -6637,7 +6637,7 @@ static __global__ void flash_attn_ext_f16(
66376637
for (int64_t p = lane_id; p < C; p += NW) {
66386638
const half s = ss[j*T + p];
66396639

6640-
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
6640+
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
66416641

66426642
ls += vs;
66436643

@@ -6650,7 +6650,7 @@ static __global__ void flash_attn_ext_f16(
66506650
}
66516651

66526652
// skip -INF blocks
6653-
if (__hisinf(smax)) {
6653+
if (__hisinf(smax) == -1) {
66546654
continue;
66556655
}
66566656

@@ -6735,8 +6735,8 @@ static __global__ void flash_attn_ext_f16(
67356735

67366736
M = __hmax(M0, M1);
67376737

6738-
const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
6739-
const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);
6738+
const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M);
6739+
const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M);
67406740

67416741
S = S0*ms0 + S1*ms1;
67426742

0 commit comments

Comments
 (0)