File tree Expand file tree Collapse file tree 2 files changed +8
-0
lines changed Expand file tree Collapse file tree 2 files changed +8
-0
lines changed Original file line number Diff line number Diff line change @@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f16(
92
92
for (int j0 = 0 ; j0 < ncols; j0 += nwarps) {
93
93
const int j = j0 + threadIdx .y ;
94
94
95
+ if (j0 + nwarps > ncols && j >= ncols) {
96
+ break ;
97
+ }
98
+
95
99
// Reuse KQ as temporary storage for converting Q to q8_1:
96
100
int * tmp_q_i32 = (int *) &KQ[j*D];
97
101
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof (int ));
Original file line number Diff line number Diff line change @@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f32(
92
92
for (int j0 = 0 ; j0 < ncols; j0 += nwarps) {
93
93
const int j = j0 + threadIdx .y ;
94
94
95
+ if (j0 + nwarps > ncols && j >= ncols) {
96
+ break ;
97
+ }
98
+
95
99
// Reuse KQ as temporary storage for converting Q to q8_1:
96
100
int * tmp_q_i32 = (int *) &KQ[j*D];
97
101
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof (int ));
You can’t perform that action at this time.
0 commit comments