@@ -382,37 +382,38 @@ void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
382
382
383
383
void ggml_cuda_flash_attn_ext_vec_f16_no_mma (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
384
384
const ggml_tensor * KQV = dst;
385
+ const ggml_tensor * Q = dst->src [0 ];
385
386
386
387
const int32_t precision = KQV->op_params [2 ];
387
388
GGML_ASSERT (precision == GGML_PREC_DEFAULT);
388
389
389
- // if (Q->ne[1] == 1) {
390
- // constexpr int cols_per_block = 1;
391
- // constexpr int parallel_blocks = 4;
392
- // launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
393
- // return;
394
- // }
395
-
396
- // if (Q->ne[1] == 2) {
397
- // constexpr int cols_per_block = 2;
398
- // constexpr int parallel_blocks = 4;
399
- // launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
400
- // return;
401
- // }
402
-
403
- // if (Q->ne[1] <= 4) {
404
- // constexpr int cols_per_block = 4;
405
- // constexpr int parallel_blocks = 4;
406
- // launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
407
- // return;
408
- // }
409
-
410
- // if (Q->ne[1] <= 8) {
411
- // constexpr int cols_per_block = 8;
412
- // constexpr int parallel_blocks = 4;
413
- // launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
414
- // return;
415
- // }
390
+ if (Q->ne [1 ] == 1 ) {
391
+ constexpr int cols_per_block = 1 ;
392
+ constexpr int parallel_blocks = 4 ;
393
+ launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
394
+ return ;
395
+ }
396
+
397
+ if (Q->ne [1 ] == 2 ) {
398
+ constexpr int cols_per_block = 2 ;
399
+ constexpr int parallel_blocks = 4 ;
400
+ launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
401
+ return ;
402
+ }
403
+
404
+ if (Q->ne [1 ] <= 4 ) {
405
+ constexpr int cols_per_block = 4 ;
406
+ constexpr int parallel_blocks = 4 ;
407
+ launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
408
+ return ;
409
+ }
410
+
411
+ if (Q->ne [1 ] <= 8 ) {
412
+ constexpr int cols_per_block = 8 ;
413
+ constexpr int parallel_blocks = 4 ;
414
+ launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
415
+ return ;
416
+ }
416
417
417
418
constexpr int cols_per_block = 8 ;
418
419
constexpr int parallel_blocks = 1 ;
0 commit comments