Skip to content

Commit b150abe

Browse files
committed
cuda : avoid warp_reduce for smax
1 parent b68a112 commit b150abe

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

ggml-cuda.cu

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6621,7 +6621,6 @@ static __global__ void flash_attn_ext_f16(
6621 6621
M[j] = __hmax(M[j], s);
6622 6622
}
6623 6623

6624 -
smax = warp_reduce_max(smax);
6625 6624
M[j] = warp_reduce_max(M[j]);
6626 6625

6627 6626
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
@@ -6649,6 +6648,8 @@ static __global__ void flash_attn_ext_f16(
6649 6648
}
6650 6649
}
6651 6650

6651 +
smax = warp_reduce_max(smax);
6652 +
6652 6653
// skip -INF blocks
6653 6654
if (__hisinf(smax) == -1) {
6654 6655
continue;

0 commit comments

Comments
 (0)