Skip to content

Commit a682474

Browse files
CUDA: fix FA tg at long context for CC >= 8.9 (#13852)
1 parent 26b79b6 commit a682474

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results(
623623
__builtin_assume(tid < D);
624624

625625
extern __shared__ float2 meta[];
626-
if (tid < 2*parallel_blocks) {
627-
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
626+
for (int i = tid; i < 2*parallel_blocks; i += D) {
627+
((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
628628
}
629629

630630
__syncthreads();

0 commit comments

Comments
 (0)