Skip to content

Commit 5c7e9c4

Browse files
fix commented-out kernel variants
1 parent a5436a0 commit 5c7e9c4

File tree

1 file changed

+28
-27
lines changed

1 file changed

+28
-27
lines changed

ggml-cuda/fattn-vec-f16.cu

Lines changed: 28 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -382,37 +382,38 @@ void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
382382

383383
void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
384384
const ggml_tensor * KQV = dst;
385+
const ggml_tensor * Q = dst->src[0];
385386

386387
const int32_t precision = KQV->op_params[2];
387388
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
388389

389-
// if (Q->ne[1] == 1) {
390-
// constexpr int cols_per_block = 1;
391-
// constexpr int parallel_blocks = 4;
392-
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
393-
// return;
394-
// }
395-
396-
// if (Q->ne[1] == 2) {
397-
// constexpr int cols_per_block = 2;
398-
// constexpr int parallel_blocks = 4;
399-
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
400-
// return;
401-
// }
402-
403-
// if (Q->ne[1] <= 4) {
404-
// constexpr int cols_per_block = 4;
405-
// constexpr int parallel_blocks = 4;
406-
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
407-
// return;
408-
// }
409-
410-
// if (Q->ne[1] <= 8) {
411-
// constexpr int cols_per_block = 8;
412-
// constexpr int parallel_blocks = 4;
413-
// launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
414-
// return;
415-
// }
390+
if (Q->ne[1] == 1) {
391+
constexpr int cols_per_block = 1;
392+
constexpr int parallel_blocks = 4;
393+
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
394+
return;
395+
}
396+
397+
if (Q->ne[1] == 2) {
398+
constexpr int cols_per_block = 2;
399+
constexpr int parallel_blocks = 4;
400+
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
401+
return;
402+
}
403+
404+
if (Q->ne[1] <= 4) {
405+
constexpr int cols_per_block = 4;
406+
constexpr int parallel_blocks = 4;
407+
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
408+
return;
409+
}
410+
411+
if (Q->ne[1] <= 8) {
412+
constexpr int cols_per_block = 8;
413+
constexpr int parallel_blocks = 4;
414+
launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
415+
return;
416+
}
416417

417418
constexpr int cols_per_block = 8;
418419
constexpr int parallel_blocks = 1;

0 commit comments

Comments (0)