@@ -6462,10 +6462,10 @@ static __global__ void flash_attn_ext_f16(
6462
6462
half16x16_acc lo[Q16][D16];
6463
6463
6464
6464
// load heads from Q to shared memory
6465
- for (int64_t j = warp_id; j < Q; j += num_warps) {
6465
+ for (int j = warp_id; j < Q; j += num_warps) {
6466
6466
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
6467
6467
6468
- for (int64_t i = lane_id; i < D2; i += NW) {
6468
+ for (int i = lane_id; i < D2; i += NW) {
6469
6469
if (iq1 + j < ne01) {
6470
6470
sq2[j*T2 + i] = __float22half2_rn (q2[i]);
6471
6471
} else {
@@ -6477,15 +6477,15 @@ static __global__ void flash_attn_ext_f16(
6477
6477
nvcuda::wmma::fill_fragment (zr, 0.0 );
6478
6478
6479
6479
// zero out lo
6480
- for (int64_t j = 0 ; j < Q16; ++j) {
6481
- for (int64_t i = 0 ; i < D16; ++i) {
6480
+ for (int j = 0 ; j < Q16; ++j) {
6481
+ for (int i = 0 ; i < D16; ++i) {
6482
6482
nvcuda::wmma::fill_fragment (lo[j][i], 0.0 );
6483
6483
}
6484
6484
}
6485
6485
6486
6486
// zero out shared memory SH
6487
- for (int64_t j = 0 ; j < Q; ++j) {
6488
- for (int64_t i = lane_id; i < SH; i += NW) {
6487
+ for (int j = 0 ; j < Q; ++j) {
6488
+ for (int i = lane_id; i < SH; i += NW) {
6489
6489
ss[j*T + i] = 0.0 ;
6490
6490
}
6491
6491
}
@@ -6526,8 +6526,8 @@ static __global__ void flash_attn_ext_f16(
6526
6526
6527
6527
// load the queries from shared memory into local memory
6528
6528
half16x16_a mq[Q16][D16];
6529
- for (int64_t j = 0 ; j < Q16; ++j) {
6530
- for (int64_t i = 0 ; i < D16; ++i) {
6529
+ for (int j = 0 ; j < Q16; ++j) {
6530
+ for (int i = 0 ; i < D16; ++i) {
6531
6531
nvcuda::wmma::load_matrix_sync (mq[j][i], sq + 16 *j*T + i*16 , T);
6532
6532
}
6533
6533
}
@@ -6544,28 +6544,28 @@ static __global__ void flash_attn_ext_f16(
6544
6544
6545
6545
// loop over the KV cache
6546
6546
// each simdgroup handles blocks of Q rows and C columns
6547
- for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
6547
+ for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) {
6548
6548
// Q*K^T
6549
6549
{
6550
6550
for (int cc = 0 ; cc < C/16 ; ++cc) {
6551
6551
half16x16_acc mqk[Q16];
6552
- for (int64_t j = 0 ; j < Q16; ++j) {
6552
+ for (int j = 0 ; j < Q16; ++j) {
6553
6553
nvcuda::wmma::fill_fragment (mqk[j], 0 );
6554
6554
}
6555
6555
6556
6556
const half * pk = (const half *) ((const char *) k + ((ic + 16 *cc)*nb11 + ik2*nb12 + ik3*nb13));
6557
6557
6558
- for (int64_t i = 0 ; i < D16; ++i) {
6558
+ for (int i = 0 ; i < D16; ++i) {
6559
6559
half16x16_bT mk; // transposed key
6560
6560
nvcuda::wmma::load_matrix_sync (mk, pk + i*16 , nb11/sizeof (half));
6561
6561
6562
- for (int64_t j = 0 ; j < Q16; ++j) {
6562
+ for (int j = 0 ; j < Q16; ++j) {
6563
6563
nvcuda::wmma::mma_sync (mqk[j], mq[j][i], mk, mqk[j]);
6564
6564
}
6565
6565
}
6566
6566
6567
6567
// mqk = mqk*scale + mask
6568
- for (int64_t j = 0 ; j < Q16; ++j) {
6568
+ for (int j = 0 ; j < Q16; ++j) {
6569
6569
half16x16_a mqka;
6570
6570
half16x16_acc mm;
6571
6571
@@ -6588,8 +6588,8 @@ static __global__ void flash_attn_ext_f16(
6588
6588
6589
6589
// online softmax
6590
6590
if (C == 32 ) {
6591
- for (int64_t j = 0 ; j < Q; ++j) {
6592
- const int64_t p = lane_id;
6591
+ for (int j = 0 ; j < Q; ++j) {
6592
+ const int p = lane_id;
6593
6593
6594
6594
const half m = M[j];
6595
6595
const half s = ss[j*T + p];
@@ -6611,10 +6611,10 @@ static __global__ void flash_attn_ext_f16(
6611
6611
ss[j*T + p] = vs;
6612
6612
}
6613
6613
} else {
6614
- for (int64_t j = 0 ; j < Q; ++j) {
6614
+ for (int j = 0 ; j < Q; ++j) {
6615
6615
const half m = M[j];
6616
6616
6617
- for (int64_t p = lane_id; p < C; p += NW) {
6617
+ for (int p = lane_id; p < C; p += NW) {
6618
6618
const half s = ss[j*T + p];
6619
6619
6620
6620
smax = __hmax (smax, s);
@@ -6633,7 +6633,7 @@ static __global__ void flash_attn_ext_f16(
6633
6633
// local sum
6634
6634
half ls = 0 .0f ;
6635
6635
6636
- for (int64_t p = lane_id; p < C; p += NW) {
6636
+ for (int p = lane_id; p < C; p += NW) {
6637
6637
const half s = ss[j*T + p];
6638
6638
6639
6639
const half vs = __hisinf (s) == -1 ? __float2half (0 .0f ) : hexp (s - M[j]);
@@ -6656,13 +6656,13 @@ static __global__ void flash_attn_ext_f16(
6656
6656
}
6657
6657
6658
6658
// O = diag(ms)*O
6659
- for (int64_t j = 0 ; j < Q16; ++j) {
6659
+ for (int j = 0 ; j < Q16; ++j) {
6660
6660
half16x16_a mm;
6661
6661
half16x16_b lob;
6662
6662
6663
6663
nvcuda::wmma::load_matrix_sync (mm, ss + 16 *j*T + C + 16 *j, T);
6664
6664
6665
- for (int64_t i = 0 ; i < D16; ++i) {
6665
+ for (int i = 0 ; i < D16; ++i) {
6666
6666
// convert accumulator to matrix_b
6667
6667
nvcuda::wmma::store_matrix_sync ( ss + 16 *j*T + C + 16 *j, lo[j][i], T, nvcuda::wmma::mem_row_major);
6668
6668
nvcuda::wmma::load_matrix_sync (lob, ss + 16 *j*T + C + 16 *j, T);
@@ -6680,17 +6680,17 @@ static __global__ void flash_attn_ext_f16(
6680
6680
const half * pv = (const half *) ((const char *) v + ((ic + 16 *cc)*nb21 + iv2*nb22 + iv3*nb23));
6681
6681
6682
6682
half16x16_b mk[D16];
6683
- for (int64_t i = 0 ; i < D16; ++i) {
6683
+ for (int i = 0 ; i < D16; ++i) {
6684
6684
nvcuda::wmma::load_matrix_sync (mk[i], pv + i*16 , nb21/sizeof (half));
6685
6685
}
6686
6686
6687
6687
half16x16_a mv[Q16];
6688
- for (int64_t j = 0 ; j < Q16; ++j) {
6688
+ for (int j = 0 ; j < Q16; ++j) {
6689
6689
nvcuda::wmma::load_matrix_sync (mv[j], ss + 16 *j*T + 16 *cc, T);
6690
6690
}
6691
6691
6692
- for (int64_t j = 0 ; j < Q16; ++j) {
6693
- for (int64_t i = 0 ; i < D16; ++i) {
6692
+ for (int j = 0 ; j < Q16; ++j) {
6693
+ for (int i = 0 ; i < D16; ++i) {
6694
6694
nvcuda::wmma::mma_sync (lo[j][i], mv[j], mk[i], lo[j][i]);
6695
6695
}
6696
6696
}
@@ -6699,7 +6699,7 @@ static __global__ void flash_attn_ext_f16(
6699
6699
}
6700
6700
6701
6701
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
6702
- for (int64_t j = 0 ; j < Q; ++j) {
6702
+ for (int j = 0 ; j < Q; ++j) {
6703
6703
if (lane_id == 0 ) {
6704
6704
ss[j*T + 0 ] = S[j];
6705
6705
ss[j*T + 1 ] = M[j];
@@ -6708,16 +6708,16 @@ static __global__ void flash_attn_ext_f16(
6708
6708
}
6709
6709
6710
6710
// reduce the warps sequentially
6711
- for (int64_t sg = 1 ; sg < num_warps; ++sg) {
6711
+ for (int sg = 1 ; sg < num_warps; ++sg) {
6712
6712
half S = __float2half (0 .0f );
6713
6713
half M = __float2half (-INFINITY);
6714
6714
6715
6715
__syncthreads ();
6716
6716
6717
6717
// each simdgroup stores its output to shared memory, reusing sq
6718
6718
if (warp_id == sg) {
6719
- for (int64_t j = 0 ; j < Q16; ++j) {
6720
- for (int64_t i = 0 ; i < D16; ++i) {
6719
+ for (int j = 0 ; j < Q16; ++j) {
6720
+ for (int i = 0 ; i < D16; ++i) {
6721
6721
nvcuda::wmma::store_matrix_sync (sq + 16 *j*T + i*16 , lo[j][i], T, nvcuda::wmma::mem_row_major);
6722
6722
}
6723
6723
}
@@ -6727,7 +6727,7 @@ static __global__ void flash_attn_ext_f16(
6727
6727
6728
6728
// the first simdgroup accumulates the results from the other simdgroups
6729
6729
if (warp_id == 0 ) {
6730
- for (int64_t j = 0 ; j < Q; ++j) {
6730
+ for (int j = 0 ; j < Q; ++j) {
6731
6731
const half S0 = ss[j*T + 0 ];
6732
6732
const half S1 = ss[j*T + sg*SH + 0 ];
6733
6733
@@ -6751,7 +6751,7 @@ static __global__ void flash_attn_ext_f16(
6751
6751
}
6752
6752
6753
6753
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
6754
- for (int64_t j = 0 ; j < Q16; ++j) {
6754
+ for (int j = 0 ; j < Q16; ++j) {
6755
6755
half16x16_a ms0;
6756
6756
half16x16_a ms1;
6757
6757
half16x16_b t;
@@ -6760,7 +6760,7 @@ static __global__ void flash_attn_ext_f16(
6760
6760
nvcuda::wmma::load_matrix_sync (ms0, ss + 16 *j*T + C + 16 *j, T);
6761
6761
nvcuda::wmma::load_matrix_sync (ms1, ss + 16 *j*T + C + 16 *j + sg*SH, T);
6762
6762
6763
- for (int64_t i = 0 ; i < D16; ++i) {
6763
+ for (int i = 0 ; i < D16; ++i) {
6764
6764
nvcuda::wmma::load_matrix_sync (t, sq + 16 *j*T + i*16 , T);
6765
6765
nvcuda::wmma::mma_sync (t2, ms1, t, zr);
6766
6766
@@ -6776,19 +6776,19 @@ static __global__ void flash_attn_ext_f16(
6776
6776
6777
6777
// store result to shared memory (reuse sq)
6778
6778
if (warp_id == 0 ) {
6779
- for (int64_t j = 0 ; j < Q16; ++j) {
6780
- for (int64_t i = 0 ; i < D16; ++i) {
6779
+ for (int j = 0 ; j < Q16; ++j) {
6780
+ for (int i = 0 ; i < D16; ++i) {
6781
6781
nvcuda::wmma::store_matrix_sync (sq + 16 *j*T + i*16 , lo[j][i], T, nvcuda::wmma::mem_row_major);
6782
6782
}
6783
6783
}
6784
6784
}
6785
6785
6786
6786
// final rescale with 1/S and store to global memory
6787
6787
if (warp_id == 0 ) {
6788
- for (int64_t j = 0 ; j < Q && iq1 + j < ne01; ++j) {
6788
+ for (int j = 0 ; j < Q && iq1 + j < ne01; ++j) {
6789
6789
const half S = ss[j*T + 0 ];
6790
6790
6791
- for (int64_t i = lane_id; i < D; i += NW) {
6791
+ for (int i = lane_id; i < D; i += NW) {
6792
6792
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float (sq[j*T + i] / S);
6793
6793
}
6794
6794
}
0 commit comments