@@ -6462,10 +6462,20 @@ static __global__ void flash_attn_ext_f16(
6462
6462
half16x16_acc lo[Q16][D16];
6463
6463
6464
6464
// load heads from Q to shared memory
6465
- for (int j = warp_id; j < Q; j += num_warps) {
6465
+ for (int j0 = 0 ; j0 < Q; j0 += num_warps) {
6466
+ const int j = j0 + warp_id;
6467
+ if (j >= Q) {
6468
+ break ;
6469
+ }
6470
+
6466
6471
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
6467
6472
6468
- for (int i = lane_id; i < D2; i += NW) {
6473
+ for (int i0 = 0 ; i0 < D2; i0 += NW) {
6474
+ const int i = i0 + lane_id;
6475
+ if (i >= D2) {
6476
+ break ;
6477
+ }
6478
+
6469
6479
if (iq1 + j < ne01) {
6470
6480
sq2[j*T2 + i] = __float22half2_rn (q2[i]);
6471
6481
} else {
@@ -6485,7 +6495,12 @@ static __global__ void flash_attn_ext_f16(
6485
6495
6486
6496
// zero out shared memory SH
6487
6497
for (int j = 0 ; j < Q; ++j) {
6488
- for (int i = lane_id; i < SH; i += NW) {
6498
+ for (int i0 = 0 ; i0 < SH; i0 += NW) {
6499
+ const int i = i0 + lane_id;
6500
+ if (i >= SH) {
6501
+ break ;
6502
+ }
6503
+
6489
6504
ss[j*T + i] = 0.0 ;
6490
6505
}
6491
6506
}
@@ -6544,7 +6559,12 @@ static __global__ void flash_attn_ext_f16(
6544
6559
6545
6560
// loop over the KV cache
6546
6561
// each simdgroup handles blocks of Q rows and C columns
6547
- for (int ic = C*warp_id; ic < ne11; ic += C*num_warps) {
6562
+ for (int ic0 = 0 ; ic0 < ne11; ic0 += C*num_warps) {
6563
+ const int ic = ic0 + warp_id*C;
6564
+ if (ic >= ne11) {
6565
+ break ;
6566
+ }
6567
+
6548
6568
// Q*K^T
6549
6569
{
6550
6570
for (int cc = 0 ; cc < C/16 ; ++cc) {
@@ -6614,7 +6634,9 @@ static __global__ void flash_attn_ext_f16(
6614
6634
for (int j = 0 ; j < Q; ++j) {
6615
6635
const half m = M[j];
6616
6636
6617
- for (int p = lane_id; p < C; p += NW) {
6637
+ for (int p0 = 0 ; p0 < C; p0 += NW) {
6638
+ const int p = p0 + lane_id;
6639
+
6618
6640
const half s = ss[j*T + p];
6619
6641
6620
6642
smax = __hmax (smax, s);
@@ -6633,7 +6655,9 @@ static __global__ void flash_attn_ext_f16(
6633
6655
// local sum
6634
6656
half ls = 0 .0f ;
6635
6657
6636
- for (int p = lane_id; p < C; p += NW) {
6658
+ for (int p0 = 0 ; p0 < C; p0 += NW) {
6659
+ const int p = p0 + lane_id;
6660
+
6637
6661
const half s = ss[j*T + p];
6638
6662
6639
6663
const half vs = __hisinf (s) == -1 ? __float2half (0 .0f ) : hexp (s - M[j]);
@@ -6788,7 +6812,12 @@ static __global__ void flash_attn_ext_f16(
6788
6812
for (int j = 0 ; j < Q && iq1 + j < ne01; ++j) {
6789
6813
const half S = ss[j*T + 0 ];
6790
6814
6791
- for (int i = lane_id; i < D; i += NW) {
6815
+ for (int i0 = 0 ; i0 < D; i0 += NW) {
6816
+ const int i = i0 + lane_id;
6817
+ if (i >= D) {
6818
+ break ;
6819
+ }
6820
+
6792
6821
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float (sq[j*T + i] / S);
6793
6822
}
6794
6823
}
0 commit comments