@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wpass-failed"
-#endif // __clang__
-
-template <int D, int ncols, int KQ_stride> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+template <int D, int ncols1, int ncols2, int KQ_stride> // D == head size
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
-    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
-
-    const int iter_k = ne11 / KQ_stride;
-    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+    constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
+    const int j     = blockIdx.y;
+    const int c     = blockIdx.z;
+    const int jc    = j*ncols2 + c;
+    const int tid   = threadIdx.x;
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
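For context on the hunk above: stream-k splits the total `iter_k*iter_j*(ne02/ncols2)` k-block chunks evenly across the launched blocks, so the first chunk of a tile may have been produced by a different block than the one that finished the tile. Below is a minimal host-side sketch of that split; `iter_k`, `iter_j`, `ne02`, `ncols2`, and the grid size are all hypothetical values, not taken from the patch.

```cpp
// Host-side sketch of the stream-k work split, mirroring the kernel's kbc0
// arithmetic. All sizes below are illustrative, not from the patch.
#include <cstdio>

int main() {
    const int iter_k = 8, iter_j = 3, ne02 = 4, ncols2 = 2; // hypothetical sizes
    const int gridDim_x = 5;                                // hypothetical grid size
    const int nchunks = iter_k*iter_j*(ne02/ncols2);

    for (int bidx0 = 0; bidx0 < gridDim_x; ++bidx0) {
        const int kbc0      = (bidx0 + 0)*nchunks / gridDim_x;
        const int kbc0_stop = (bidx0 + 1)*nchunks / gridDim_x;
        // A fixup is needed iff this block's first tile was begun by an earlier
        // block, i.e. its range does not start on an iter_k boundary.
        const bool needs_fixup = kbc0 != kbc0_stop && kbc0 % iter_k != 0;
        printf("block %d: kbc [%2d, %2d) needs_fixup=%d\n",
               bidx0, kbc0, kbc0_stop, needs_fixup ? 1 : 0);
    }
    return 0;
}
```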
@@ -548,59 +546,53 @@ static __global__ void flash_attn_stream_k_fixup(
     const int channel = kbc0 / (iter_k*iter_j);
     const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
 
-    dst += jt*ncols*ne02*D + channel*D;
+    if (jt*ncols1 + j >= ne01) {
+        return;
+    }
 
-    // Load the partial result that needs a fixup:
-    float dst_val[ncols] = {0.0f};
-    float max_val[ncols] = {0.0f};
-    float rowsum[ncols]  = {0.0f};
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            break;
-        }
-        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
 
-        const float2 tmp = dst_fixup[bidx0*ncols + j];
-        max_val[j] = tmp.x;
-        rowsum[j]  = tmp.y;
+    // Load the partial result that needs a fixup:
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum  = 0.0f;
+    {
+        dst_val = *dst;
+
+        const float2 tmp = dst_fixup[bidx0*ncols + jc];
+        max_val = tmp.x;
+        rowsum  = tmp.y;
     }
 
     // Iterate over previous blocks and compute the combined results.
     // All CUDA blocks that get here must have a previous block that needs a fixup.
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while (true) {
-        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
            bidx--;
            kbc_stop = kbc;
            continue;
        }
 
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            if (jt*ncols + j >= ne01) {
-                break;
-            }
-            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
 
-            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
 
-            // Scale the current and new value accumulators depending on the max. values.
-            const float max_val_new = fmaxf(max_val[j], tmp.x);
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = fmaxf(max_val, tmp.x);
 
-            const float diff_val = max_val[j] - max_val_new;
-            const float diff_add = tmp.x     - max_val_new;
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x   - max_val_new;
 
-            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
-            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
 
-            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
-            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum  = scale_val*rowsum  + scale_add*tmp.y;
 
-            max_val[j] = max_val_new;
-        }
+        max_val = max_val_new;
 
         // If this block started in a previous tile we are done and don't need to combine additional partial results.
         if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
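The loop above applies the standard online-softmax merge: when a row's partial results come from two blocks, both accumulators are rescaled by `exp(old_max - new_max)` before being summed, so the final `dst_val / rowsum` matches what a single block would have computed. A standalone sketch of that rule follows; the input triples are made up and the `SOFTMAX_FTZ_THRESHOLD` flush-to-zero check is omitted for brevity.

```cpp
// Standalone sketch of the fixup's rescaling rule: merging two partial
// (value accumulator, row max, row sum) triples in a numerically stable way.
// Inputs are hypothetical; the FTZ threshold check is omitted.
#include <cmath>
#include <cstdio>

struct Partial { float val, max, sum; };

static Partial combine(const Partial a, const Partial b) {
    const float max_new = std::fmax(a.max, b.max);
    const float scale_a = std::exp(a.max - max_new);
    const float scale_b = std::exp(b.max - max_new);
    return { scale_a*a.val + scale_b*b.val, max_new, scale_a*a.sum + scale_b*b.sum };
}

int main() {
    const Partial a = {2.0f, 1.0f, 1.5f}; // hypothetical partial result
    const Partial b = {3.0f, 2.0f, 2.5f}; // hypothetical partial result
    const Partial c = combine(a, b);
    printf("val=%g sum=%g -> dst=%g\n", c.val, c.sum, c.val/c.sum);
    return 0;
}
```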
@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 
     // Write back final result:
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            return;
-        }
-        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
-    }
+    *dst = dst_val / rowsum;
 }
 
-#ifdef __clang__
-#pragma clang diagnostic pop
-#endif // __clang__
-
 template <int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 // parallel_blocks == 0 is stream-k decomposition
-template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
+template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
         const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
+    constexpr int ncols = ncols1*ncols2;
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -763,25 +747,26 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
-    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
+    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int ntiles_total = ntiles_x * (Q->ne[2]/ncols2) * Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
-        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
+        const int max_blocks = 2*nsm;
+        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
-        const int nblocks_stream_k = 2*nsm;
+        const int nblocks_stream_k = max_blocks;
 
-        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
 
         blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
         blocks_num.x = parallel_blocks*ntiles_x;
         blocks_num.y = Q->ne[2];
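The heuristic in this hunk measures wave quantization: with at most `2*nsm` blocks resident, the last wave of tiles may be only partially full, and `tiles_efficiency_percent` is the share of launched block-slots doing useful work. A quick standalone check of the arithmetic, with made-up sizes:

```cpp
// Standalone sketch of the wave-quantization heuristic; nsm and ntiles_total
// are hypothetical values, not from the patch.
#include <cstdio>

int main() {
    const int nsm = 46;           // hypothetical SM count
    const int ntiles_total = 100; // hypothetical number of tiles
    const int max_blocks = 2*nsm;
    const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
    const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
    // 100 tiles over 92 blocks -> 2 waves, 54% efficiency: stream-k would be used.
    printf("waves=%d efficiency=%d%%\n", tiles_nwaves, tiles_efficiency_percent);
    return 0;
}
```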
@@ -793,7 +778,6 @@ void launch_fattn(
         }
     }
 
-
     float scale         = 1.0f;
     float max_bias      = 0.0f;
    float logit_softcap = 0.0f;
@@ -832,9 +816,9 @@ void launch_fattn(
     if constexpr (parallel_blocks == 0) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
-            const dim3 blocks_num_combine = blocks_num;
+            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
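With the 3D combine grid introduced here, each fixup block handles exactly one output row: `blockIdx.y` selects one of the `ncols1` Q columns, `blockIdx.z` one of the `ncols2` grouped channels, and the kernel flattens the pair to `jc = j*ncols2 + c` when indexing the metadata arrays. A trivial sketch of that mapping (sizes hypothetical):

```cpp
// Sketch of the flat row index used by the fixup kernel for its metadata
// arrays; the ncols1/ncols2 values are illustrative only.
#include <cstdio>

int main() {
    const int ncols1 = 4, ncols2 = 2; // hypothetical template arguments
    for (int j = 0; j < ncols1; ++j) {     // corresponds to blockIdx.y
        for (int c = 0; c < ncols2; ++c) { // corresponds to blockIdx.z
            printf("j=%d c=%d -> jc=%d\n", j, c, j*ncols2 + c);
        }
    }
    return 0;
}
```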