Skip to content

Commit f4003cf

Browse files
fix nwarps > batch size
1 parent f087760 commit f4003cf

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

ggml-cuda/fattn-vec-f16.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f16(
9292
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
9393
const int j = j0 + threadIdx.y;
9494

95+
if (j0 + nwarps > ncols && j >= ncols) {
96+
break;
97+
}
98+
9599
// Reuse KQ as temporary storage for converting Q to q8_1:
96100
int * tmp_q_i32 = (int *) &KQ[j*D];
97101
half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int));

ggml-cuda/fattn-vec-f32.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f32(
9292
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
9393
const int j = j0 + threadIdx.y;
9494

95+
if (j0 + nwarps > ncols && j >= ncols) {
96+
break;
97+
}
98+
9599
// Reuse KQ as temporary storage for converting Q to q8_1:
96100
int * tmp_q_i32 = (int *) &KQ[j*D];
97101
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));

0 commit comments

Comments
 (0)