File tree Expand file tree Collapse file tree 4 files changed +18
-2
lines changed Expand file tree Collapse file tree 4 files changed +18
-2
lines changed Original file line number Diff line number Diff line change @@ -238,6 +238,10 @@ static __global__ void flash_attn_tile_ext_f16(
238
238
for (int j_VKQ_0 = 0 ; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
239
239
const int j_VKQ = j_VKQ_0 + threadIdx .y ;
240
240
241
+ if (ic0 + j_VKQ >= ne01) {
242
+ return ;
243
+ }
244
+
241
245
half kqsum_j = __low2half (kqsum[j_VKQ_0/nwarps]) + __high2half (kqsum[j_VKQ_0/nwarps]);
242
246
kqsum_j = warp_reduce_sum (kqsum_j);
243
247
Original file line number Diff line number Diff line change @@ -237,6 +237,10 @@ static __global__ void flash_attn_tile_ext_f32(
237
237
for (int j_VKQ_0 = 0 ; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
238
238
const int j_VKQ = j_VKQ_0 + threadIdx .y ;
239
239
240
+ if (ic0 + j_VKQ >= ne01) {
241
+ return ;
242
+ }
243
+
240
244
float kqsum_j = kqsum[j_VKQ_0/nwarps];
241
245
kqsum_j = warp_reduce_sum (kqsum_j);
242
246
Original file line number Diff line number Diff line change @@ -212,6 +212,10 @@ static __global__ void flash_attn_vec_ext_f16(
212
212
213
213
#pragma unroll
214
214
for (int j_VKQ = 0 ; j_VKQ < ncols; ++j_VKQ) {
215
+ if (ic0 + j_VKQ >= ne01) {
216
+ break ;
217
+ }
218
+
215
219
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx .x ];
216
220
kqsum[j_VKQ] = warp_reduce_sum (kqsum[j_VKQ]);
217
221
@@ -223,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
223
227
dst[j_dst*D*gridDim .y + D*blockIdx .y + tid] = dst_val;
224
228
}
225
229
226
- if (parallel_blocks != 1 && tid < ncols) {
230
+ if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01 ) {
227
231
dst_meta[(ic0 + tid)*gridDim .y *parallel_blocks + blockIdx .y *parallel_blocks + ip] = make_float2 (kqmax[tid], kqsum[tid]);
228
232
}
229
233
#else
Original file line number Diff line number Diff line change @@ -200,6 +200,10 @@ static __global__ void flash_attn_vec_ext_f32(
200
200
201
201
#pragma unroll
202
202
for (int j_VKQ = 0 ; j_VKQ < ncols; ++j_VKQ) {
203
+ if (ic0 + j_VKQ >= ne01) {
204
+ break ;
205
+ }
206
+
203
207
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx .x ];
204
208
kqsum[j_VKQ] = warp_reduce_sum (kqsum[j_VKQ]);
205
209
@@ -211,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
211
215
dst[j_dst*D*gridDim .y + D*blockIdx .y + tid] = dst_val;
212
216
}
213
217
214
- if (parallel_blocks != 1 && tid < ncols) {
218
+ if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01 ) {
215
219
dst_meta[(ic0 + tid)*gridDim .y *parallel_blocks + blockIdx .y *parallel_blocks + ip] = make_float2 (kqmax[tid], kqsum[tid]);
216
220
}
217
221
}
You can’t perform that action at this time.
0 commit comments