@@ -6597,8 +6597,8 @@ static __global__ void flash_attn_ext_f16(
6597
6597
smax = warp_reduce_max (__hmax (smax, s));
6598
6598
M[j] = warp_reduce_max (__hmax (M[j], s));
6599
6599
6600
- const half ms = __hisinf (m) ? __float2half (0 .0f ) : hexp (m - M[j]);
6601
- const half vs = __hisinf (s) ? __float2half (0 .0f ) : hexp (s - M[j]);
6600
+ const half ms = __hisinf (m) == - 1 ? __float2half (0 .0f ) : hexp (m - M[j]);
6601
+ const half vs = __hisinf (s) == - 1 ? __float2half (0 .0f ) : hexp (s - M[j]);
6602
6602
6603
6603
S[j] = S[j]*ms + warp_reduce_sum (vs);
6604
6604
@@ -6624,7 +6624,7 @@ static __global__ void flash_attn_ext_f16(
6624
6624
smax = warp_reduce_max (smax);
6625
6625
M[j] = warp_reduce_max (M[j]);
6626
6626
6627
- const half ms = __hisinf (m) ? __float2half (0 .0f ) : hexp (m - M[j]);
6627
+ const half ms = __hisinf (m) == - 1 ? __float2half (0 .0f ) : hexp (m - M[j]);
6628
6628
6629
6629
// create a QxQ diagonal matrix for rescaling the output
6630
6630
if (lane_id == j) {
@@ -6637,7 +6637,7 @@ static __global__ void flash_attn_ext_f16(
6637
6637
for (int64_t p = lane_id; p < C; p += NW) {
6638
6638
const half s = ss[j*T + p];
6639
6639
6640
- const half vs = __hisinf (s) ? __float2half (0 .0f ) : hexp (s - M[j]);
6640
+ const half vs = __hisinf (s) == - 1 ? __float2half (0 .0f ) : hexp (s - M[j]);
6641
6641
6642
6642
ls += vs;
6643
6643
@@ -6650,7 +6650,7 @@ static __global__ void flash_attn_ext_f16(
6650
6650
}
6651
6651
6652
6652
// skip -INF blocks
6653
- if (__hisinf (smax)) {
6653
+ if (__hisinf (smax) == - 1 ) {
6654
6654
continue ;
6655
6655
}
6656
6656
@@ -6735,8 +6735,8 @@ static __global__ void flash_attn_ext_f16(
6735
6735
6736
6736
M = __hmax (M0, M1);
6737
6737
6738
- const half ms0 = __hisinf (M0) ? __float2half (0 .0f ) : hexp (M0 - M);
6739
- const half ms1 = __hisinf (M1) ? __float2half (0 .0f ) : hexp (M1 - M);
6738
+ const half ms0 = __hisinf (M0) == - 1 ? __float2half (0 .0f ) : hexp (M0 - M);
6739
+ const half ms1 = __hisinf (M1) == - 1 ? __float2half (0 .0f ) : hexp (M1 - M);
6740
6740
6741
6741
S = S0*ms0 + S1*ms1;
6742
6742
0 commit comments