@@ -6495,8 +6495,8 @@ static __global__ void flash_attn_ext_f16(
6495
6495
half M[Q];
6496
6496
6497
6497
for (int i = 0 ; i < Q; i++) {
6498
- S[i] = 0 .0f ;
6499
- M[i] = -INFINITY;
6498
+ S[i] = __float2half ( 0 .0f ) ;
6499
+ M[i] = __float2half ( -INFINITY) ;
6500
6500
}
6501
6501
6502
6502
// assume K and V are same shape
@@ -6579,7 +6579,7 @@ static __global__ void flash_attn_ext_f16(
6579
6579
}
6580
6580
6581
6581
// used to detect blocks full of -INF
6582
- half smax = -INFINITY;
6582
+ half smax = __float2half ( -INFINITY) ;
6583
6583
6584
6584
// online softmax
6585
6585
if (C == 32 ) {
@@ -6592,8 +6592,8 @@ static __global__ void flash_attn_ext_f16(
6592
6592
smax = warp_reduce_max (__hmax (smax, s));
6593
6593
M[j] = warp_reduce_max (__hmax (M[j], s));
6594
6594
6595
- const half ms = __hisinf (m) ? 0 .0f : expf (m - M[j]);
6596
- const half vs = __hisinf (s) ? 0 .0f : expf (s - M[j]);
6595
+ const half ms = __hisinf (m) ? __float2half ( 0 .0f ) : hexp (m - M[j]);
6596
+ const half vs = __hisinf (s) ? __float2half ( 0 .0f ) : hexp (s - M[j]);
6597
6597
6598
6598
S[j] = S[j]*ms + warp_reduce_sum (vs);
6599
6599
@@ -6612,33 +6612,38 @@ static __global__ void flash_attn_ext_f16(
6612
6612
for (int64_t p = lane_id; p < C; p += NW) {
6613
6613
const half s = ss[j*T + p];
6614
6614
6615
- smax = warp_reduce_max ( __hmax (smax, s) );
6616
- M[j] = warp_reduce_max ( __hmax (M[j], s) );
6615
+ smax = __hmax (smax, s);
6616
+ M[j] = __hmax (M[j], s);
6617
6617
}
6618
6618
6619
- const half ms = __hisinf (m) ? 0 .0f : expf (m - M[j]);
6619
+ smax = warp_reduce_max (smax);
6620
+ M[j] = warp_reduce_max (M[j]);
6620
6621
6621
- S[j] = S [j]*ms ;
6622
+ const half ms = __hisinf (m) ? __float2half ( 0 . 0f ) : hexp (m - M [j]) ;
6622
6623
6623
6624
// create a QxQ diagonal matrix for rescaling the output
6624
6625
if (lane_id == j) {
6625
6626
ss[j*T + C + j] = ms;
6626
6627
}
6627
6628
6629
+ // local sum
6630
+ half ls = 0 .0f ;
6631
+
6628
6632
for (int64_t p = lane_id; p < C; p += NW) {
6629
6633
const half s = ss[j*T + p];
6630
6634
6631
- const half vs = __hisinf (s) ? 0 .0f : expf (s - M[j]);
6635
+ const half vs = __hisinf (s) ? __float2half ( 0 .0f ) : hexp (s - M[j]);
6632
6636
6633
- S[j] = S[j] + warp_reduce_sum (vs) ;
6637
+ ls += vs ;
6634
6638
6635
6639
// the P matrix from the paper (Q rows, C columns)
6636
6640
ss[j*T + p] = vs;
6637
6641
}
6642
+
6643
+ S[j] = S[j]*ms + warp_reduce_sum (ls);
6638
6644
}
6639
6645
}
6640
6646
6641
-
6642
6647
// skip -INF blocks
6643
6648
if (__hisinf (smax)) {
6644
6649
continue ;
@@ -6669,15 +6674,19 @@ static __global__ void flash_attn_ext_f16(
6669
6674
for (int cc = 0 ; cc < C/16 ; ++cc) {
6670
6675
const half * pv = (const half *) ((const char *) v + ((ic + 16 *cc)*nb21 + iv2*nb22 + iv3*nb23));
6671
6676
6677
+ half16x16_b mk[D16];
6672
6678
for (int64_t i = 0 ; i < D16; ++i) {
6673
- half16x16_b mk ;
6674
- nvcuda::wmma::load_matrix_sync (mk, pv + i* 16 , nb21/ sizeof (half));
6679
+ nvcuda::wmma::load_matrix_sync (mk[i], pv + i* 16 , nb21/ sizeof (half)) ;
6680
+ }
6675
6681
6676
- for (int64_t j = 0 ; j < Q16; ++j) {
6677
- half16x16_a mv;
6678
- nvcuda::wmma::load_matrix_sync (mv, ss + 16 *j*T + 16 *cc, T);
6682
+ half16x16_a mv[Q16];
6683
+ for (int64_t j = 0 ; j < Q16; ++j) {
6684
+ nvcuda::wmma::load_matrix_sync (mv[j], ss + 16 *j*T + 16 *cc, T);
6685
+ }
6679
6686
6680
- nvcuda::wmma::mma_sync (lo[j][i], mv, mk, lo[j][i]);
6687
+ for (int64_t j = 0 ; j < Q16; ++j) {
6688
+ for (int64_t i = 0 ; i < D16; ++i) {
6689
+ nvcuda::wmma::mma_sync (lo[j][i], mv[j], mk[i], lo[j][i]);
6681
6690
}
6682
6691
}
6683
6692
}
@@ -6695,8 +6704,8 @@ static __global__ void flash_attn_ext_f16(
6695
6704
6696
6705
// reduce the warps sequentially
6697
6706
for (int64_t sg = 1 ; sg < num_warps; ++sg) {
6698
- half S = 0 .0f ;
6699
- half M = -INFINITY;
6707
+ half S = __float2half ( 0 .0f ) ;
6708
+ half M = __float2half ( -INFINITY) ;
6700
6709
6701
6710
__syncthreads ();
6702
6711
@@ -6722,8 +6731,8 @@ static __global__ void flash_attn_ext_f16(
6722
6731
6723
6732
M = __hmax (M0, M1);
6724
6733
6725
- const half ms0 = __hisinf (M0) ? 0 .0f : expf (M0 - M);
6726
- const half ms1 = __hisinf (M1) ? 0 .0f : expf (M1 - M);
6734
+ const half ms0 = __hisinf (M0) ? __float2half ( 0 .0f ) : hexp (M0 - M);
6735
+ const half ms1 = __hisinf (M1) ? __float2half ( 0 .0f ) : hexp (M1 - M);
6727
6736
6728
6737
S = S0*ms0 + S1*ms1;
6729
6738
@@ -6770,8 +6779,6 @@ static __global__ void flash_attn_ext_f16(
6770
6779
}
6771
6780
}
6772
6781
6773
- // float2 * dst2 = (float2 *) dst;
6774
-
6775
6782
// final rescale with 1/S and store to global memory
6776
6783
if (warp_id == 0 ) {
6777
6784
for (int64_t j = 0 ; j < Q && iq1 + j < ne01; ++j) {
@@ -9637,7 +9644,7 @@ static void ggml_cuda_op_soft_max(
9637
9644
9638
9645
const int64_t ne00 = src0->ne [0 ];
9639
9646
const int64_t nrows_x = ggml_nrows (src0);
9640
- const int64_t nrows_y = src1 ? ggml_nrows (src1) : 1 ;
9647
+ const int64_t nrows_y = src1 ? src0-> ne [ 1 ] : 1 ; // note: using number of queries since mask can be padded!
9641
9648
9642
9649
float scale = 1 .0f ;
9643
9650
memcpy (&scale, dst->op_params , sizeof (float ));
@@ -10932,7 +10939,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
10932
10939
memcpy (&scale, KQV->op_params , sizeof (float ));
10933
10940
10934
10941
#define NQPB 16
10935
- #define NCPW 32
10942
+ #define NCPW 128
10936
10943
10937
10944
const int nqpb = NQPB; // queries per block
10938
10945
const int ncpw = NCPW; // cache values per warp (does not work for other values)
0 commit comments