Skip to content

Commit 95e1888

Browse files
CUDA: fix misaligned synchronization in FA (#13469)
1 parent df84919 commit 95e1888

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
895895
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
896896
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
897897
}
898+
} else if (np > 1) {
899+
// Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
900+
// Therefore, all other warps also need to execute a __syncthreads().
901+
// Otherwise the points at which warps synchronize with each other would become misaligned.
902+
__syncthreads();
898903
}
899904

900905
#pragma unroll

0 commit comments

Comments
 (0)