We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 26b79b6 commit a682474Copy full SHA for a682474
ggml/src/ggml-cuda/fattn-common.cuh
@@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results(
623
__builtin_assume(tid < D);
624
625
extern __shared__ float2 meta[];
626
- if (tid < 2*parallel_blocks) {
627
- ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
+ for (int i = tid; i < 2*parallel_blocks; i += D) {
+ ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
628
}
629
630
__syncthreads();
0 commit comments