@@ -516,6 +516,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
        nullptr;
}

+template <int D, int ncols, int KQ_stride> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_stream_k_fixup(
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / KQ_stride;
+    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+
+    const int bidx0 = blockIdx.x;
+
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
+
+    const int channel = kbc0 / (iter_k*iter_j);
+    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+
+    dst += jt*ncols*ne02*D + channel*D;
+
+    // Load the partial result that needs a fixup:
+    float dst_val[ncols] = {0.0f};
+    float max_val[ncols] = {0.0f};
+    float rowsum[ncols]  = {0.0f};
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            break;
+        }
+        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+
+        const float2 tmp = dst_fixup[bidx0*ncols + j];
+        max_val[j] = tmp.x;
+        rowsum[j]  = tmp.y;
+    }
+
+    // Iterate over previous blocks and compute the combined results.
+    // All CUDA blocks that get here must have a previous block that needs a fixup.
+    int bidx = bidx0 - 1;
+    int kbc_stop = kbc0;
+    while (true) {
+        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
+            continue;
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (jt*ncols + j >= ne01) {
+                break;
+            }
+            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+
+            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+
+            // Scale the current and new value accumulators depending on the max. values.
+            const float max_val_new = fmaxf(max_val[j], tmp.x);
+
+            const float diff_val = max_val[j] - max_val_new;
+            const float diff_add = tmp.x     - max_val_new;
+
+            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
+            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+
+            max_val[j] = max_val_new;
+        }
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            return;
+        }
+        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
+    }
+}
+
template <int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
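
For reference, the rescaling loop in the new kernel is the standard online-softmax merge of partial attention results. Below is a minimal host-side sketch of the same per-row math, with hypothetical names; SOFTMAX_FTZ_THRESHOLD is assumed to be the -20.0f flush-to-zero constant defined elsewhere in this file:

```cpp
// Host-side sketch (not from the diff) of the per-row merge performed by
// flash_attn_stream_k_fixup. The kernel applies this math per output element.
#include <cmath>
#include <cstdio>

struct Partial {
    float val;    // unnormalized weighted sum of V rows
    float max;    // running maximum of the KQ values
    float rowsum; // running sum of exp(KQ - max)
};

static Partial merge_partials(const Partial & a, const Partial & b) {
    const float SOFTMAX_FTZ_THRESHOLD = -20.0f; // assumed to match the file's constant

    const float max_new = fmaxf(a.max, b.max);

    // Rescale each accumulator to the new shared maximum; flush to zero when
    // the contribution would be negligible.
    const float diff_a  = a.max - max_new;
    const float diff_b  = b.max - max_new;
    const float scale_a = diff_a >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_a) : 0.0f;
    const float scale_b = diff_b >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_b) : 0.0f;

    return {scale_a*a.val + scale_b*b.val, max_new, scale_a*a.rowsum + scale_b*b.rowsum};
}

int main() {
    const Partial p = merge_partials({1.0f, 2.0f, 3.0f}, {0.5f, 4.0f, 1.5f});
    printf("final output = %f\n", p.val / p.rowsum); // divide by rowsum only at the end
    return 0;
}
```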
@@ -581,10 +679,11 @@ static void on_no_fattn_vec_case(const int D) {
    }
}

-template <int D, int parallel_blocks>
+// parallel_blocks == 0 is stream-k decomposition
+template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
void launch_fattn(
        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-        const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
) {
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
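
The new template parameters change how work is decomposed: parallel_blocks > 0 keeps the old fixed split, while parallel_blocks == 0 (stream-k) gives each CUDA block a contiguous range of KQ-tile iterations, so block boundaries can fall mid-tile and the fixup kernel above merges the fractional results. A standalone sketch of that index arithmetic, with made-up sizes:

```cpp
// Sketch of the stream-k work split (values are hypothetical, chosen only to
// show blocks whose ranges start mid-tile and therefore need the fixup).
#include <cstdio>

int main() {
    const int iter_k  = 8; // KV tiles per output tile (ne11 / KQ_stride)
    const int iter_j  = 4; // output column tiles ((ne01 + ncols - 1) / ncols)
    const int ne02    = 2; // number of heads
    const int nblocks = 6; // gridDim.x; the diff uses 2*nsm

    const int total = iter_k*iter_j*ne02;
    for (int b = 0; b < nblocks; ++b) {
        // Same expressions as kbc0/kbc0_stop in flash_attn_stream_k_fixup:
        const int kbc      = (b + 0)*total / nblocks;
        const int kbc_stop = (b + 1)*total / nblocks;
        printf("block %d: iterations [%2d, %2d), starts mid-tile: %s\n",
               b, kbc, kbc_stop, kbc % iter_k != 0 ? "yes" : "no");
    }
    return 0;
}
```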
@@ -603,20 +702,23 @@ void launch_fattn(

    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

+    GGML_ASSERT(Q->ne[3] == 1);
+
    ggml_cuda_pool & pool = ctx.pool();
    cudaStream_t main_stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

    ggml_cuda_pool_alloc<half>   K_f16(pool);
    ggml_cuda_pool_alloc<half>   V_f16(pool);
    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);

-    char  * K_data = (char *) K->data;
+    const char * K_data = (const char *) K->data;
    size_t  nb11 = K->nb[1];
    size_t  nb12 = K->nb[2];
    size_t  nb13 = K->nb[3];

-    char  * V_data = (char *) V->data;
+    const char * V_data = (const char *) V->data;
    size_t  nb21 = V->nb[1];
    size_t  nb22 = V->nb[2];
    size_t  nb23 = V->nb[3];
@@ -649,39 +751,60 @@ void launch_fattn(
        nb23 = nb23*bs*sizeof(half)/ts;
    }

-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
+    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
+    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];

    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
+    dim3 blocks_num;
+    if (parallel_blocks == 0) {
+        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
+        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
+        const bool short_context = K->ne[1] < 4096;
+
+        const int nblocks_stream_k = 2*nsm;
+
+        blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+        blocks_num.y = 1;
+        blocks_num.z = 1;
+
+        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+    } else {
+        blocks_num.x = parallel_blocks*ntiles_x;
+        blocks_num.y = Q->ne[2];
+        blocks_num.z = Q->ne[3];
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+    }
+

    float scale         = 1.0f;
    float max_bias      = 0.0f;
    float logit_softcap = 0.0f;

-    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    if (logit_softcap != 0.0f) {
        scale /= logit_softcap;
    }

    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
        (const char *) Q->data,
        K_data,
        V_data,
        mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
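
To make the grid-size heuristic in this hunk concrete, here are the same expressions evaluated host-side on made-up numbers; the formulas are copied verbatim from the diff, and nsm = 108 corresponds to an A100 but is otherwise arbitrary:

```cpp
// Worked example (hypothetical values) of the stream-k grid-size choice above.
#include <cstdio>

int main() {
    const int nsm          = 108;  // streaming multiprocessors (hypothetical device)
    const int ntiles_total = 64;   // ntiles_x * Q->ne[2] * Q->ne[3] (hypothetical)
    const int kv_len       = 2048; // K->ne[1] (hypothetical)

    const int  tiles_nwaves      = (ntiles_total - nsm - 1) / nsm;
    const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
    const bool short_context     = kv_len < 4096;

    const int nblocks_stream_k = 2*nsm;
    const int blocks_num_x = short_context && !tiles_inefficient ? ntiles_total
                                                                 : nblocks_stream_k;
    printf("blocks_num.x = %d (%s)\n", blocks_num_x,
           blocks_num_x == ntiles_total ? "whole tiles, fixup skipped" : "stream-k");
    return 0;
}
```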
@@ -693,16 +816,22 @@ void launch_fattn(
    );
    CUDA_CHECK(cudaGetLastError());

-    if ((parallel_blocks) == 1) {
-        return;
-    }
+    if constexpr (parallel_blocks == 0) {
+        if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 blocks_num_combine = blocks_num;

-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
+            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+        }
+    } else if constexpr (parallel_blocks > 1) {
+        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);

-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+        flash_attn_combine_results<D, parallel_blocks>
+            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    }

    CUDA_CHECK(cudaGetLastError());
}
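
The diff never spells out the layout of the dst_tmp_meta scratch buffer, but the allocation blocks_num.x*cols_per_block*(2*2 + D)*sizeof(float) and the gridDim.x*(2*2*ncols) offset in the fixup kernel imply one: per block and column, two float2 (max, rowsum) pairs, one written at the block's first partial tile and one at its last, followed by D floats of partial output. A small check that the sizes add up, with hypothetical dimensions:

```cpp
// Size check for the implied dst_tmp_meta layout (interpretation is mine,
// not stated in the diff). All dimensions below are hypothetical.
#include <cstdio>

int main() {
    const int nblocks = 216; // gridDim.x, e.g. 2*nsm with nsm = 108
    const int ncols   = 8;   // cols_per_block
    const int D       = 128; // head size

    const size_t meta_floats = (size_t) nblocks * 2*2*ncols; // two float2 per block/column
    const size_t data_floats = (size_t) nblocks * ncols * D; // partial dst values
    printf("scratch = %zu floats; alloc formula nblocks*ncols*(2*2 + D) = %zu\n",
           meta_floats + data_floats, (size_t) nblocks * ncols * (2*2 + D));
    return 0;
}
```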