Skip to content

Commit 38c0347

Browse files
CUDA: fix FlashAttention (FA) out-of-bounds writes (#7465)
1 parent: b18532a · commit: 38c0347

File tree

4 files changed: 18 additions (+18) and 2 deletions (-2)

ggml-cuda/fattn-tile-f16.cu

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -238,6 +238,10 @@ static __global__ void flash_attn_tile_ext_f16(
238238
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
239239
const int j_VKQ = j_VKQ_0 + threadIdx.y;
240240

241+
if (ic0 + j_VKQ >= ne01) {
242+
return;
243+
}
244+
241245
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
242246
kqsum_j = warp_reduce_sum(kqsum_j);
243247

ggml-cuda/fattn-tile-f32.cu

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -237,6 +237,10 @@ static __global__ void flash_attn_tile_ext_f32(
237237
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
238238
const int j_VKQ = j_VKQ_0 + threadIdx.y;
239239

240+
if (ic0 + j_VKQ >= ne01) {
241+
return;
242+
}
243+
240244
float kqsum_j = kqsum[j_VKQ_0/nwarps];
241245
kqsum_j = warp_reduce_sum(kqsum_j);
242246

ggml-cuda/fattn-vec-f16.cu

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,10 @@ static __global__ void flash_attn_vec_ext_f16(
212212

213213
#pragma unroll
214214
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
215+
if (ic0 + j_VKQ >= ne01) {
216+
break;
217+
}
218+
215219
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
216220
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
217221

@@ -223,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
223227
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
224228
}
225229

226-
if (parallel_blocks != 1 && tid < ncols) {
230+
if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
227231
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
228232
}
229233
#else

ggml-cuda/fattn-vec-f32.cu

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -200,6 +200,10 @@ static __global__ void flash_attn_vec_ext_f32(
200200

201201
#pragma unroll
202202
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
203+
if (ic0 + j_VKQ >= ne01) {
204+
break;
205+
}
206+
203207
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
204208
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
205209

@@ -211,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
211215
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
212216
}
213217

214-
if (parallel_blocks != 1 && tid < ncols) {
218+
if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
215219
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
216220
}
217221
}

0 commit comments

Comments
 (0)