@@ -6568,7 +6568,8 @@ static __global__ void flash_attn_ext_f16(
6568
6568
for (int64_t j = 0 ; j < Q16; ++j) {
6569
6569
half16x16_a mqka;
6570
6570
half16x16_acc mm;
6571
- if (mp) {
6571
+
6572
+ if (mp) {
6572
6573
nvcuda::wmma::load_matrix_sync (mm, mp + 16 *j*(nb31/sizeof (half)) + ic + 16 *cc, nb31/sizeof (half), nvcuda::wmma::mem_row_major);
6573
6574
}
6574
6575
@@ -10927,78 +10928,111 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
10927
10928
10928
10929
const size_t shmem = nqpb*(Q->ne [0 ] + nwarps*(ncpw + nqpb))*(sizeof (float )/2 );
10929
10930
10930
- switch (Q->ne [0 ])
10931
- {
10932
- case 16 :
10933
- flash_attn_ext_f16<16 , NQPB, NCPW>
10934
- <<<blocks_num, block_dim, shmem, main_stream>>> (
10935
- (const char *) src0_extra->data_device [g_main_device], // Query
10936
- (const char *) src1_extra->data_device [g_main_device], // Key
10937
- (const char *) src2_extra->data_device [g_main_device], // Value
10938
- mask ? ((const char *) src3_extra->data_device [g_main_device]) : nullptr , // Mask
10939
- (float *) dst_extra->data_device [g_main_device], // dst
10940
- scale,
10941
- Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
10942
- K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
10943
- mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
10944
- Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
10945
- K->nb [1 ], K->nb [2 ], K->nb [3 ],
10946
- KQV->ne [0 ], KQV->ne [1 ], KQV->ne [2 ], KQV->ne [3 ]
10947
- );
10948
- break ;
10949
- case 64 :
10950
- flash_attn_ext_f16<64 , NQPB, NCPW>
10951
- <<<blocks_num, block_dim, shmem, main_stream>>> (
10952
- (const char *) src0_extra->data_device [g_main_device], // Query
10953
- (const char *) src1_extra->data_device [g_main_device], // Key
10954
- (const char *) src2_extra->data_device [g_main_device], // Value
10955
- mask ? ((const char *) src3_extra->data_device [g_main_device]) : nullptr , // Mask
10956
- (float *) dst_extra->data_device [g_main_device], // dst
10957
- scale,
10958
- Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
10959
- K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
10960
- mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
10961
- Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
10962
- K->nb [1 ], K->nb [2 ], K->nb [3 ],
10963
- KQV->ne [0 ], KQV->ne [1 ], KQV->ne [2 ], KQV->ne [3 ]
10964
- );
10965
- break ;
10966
- case 80 :
10967
- flash_attn_ext_f16<80 , NQPB, NCPW>
10968
- <<<blocks_num, block_dim, shmem, main_stream>>> (
10969
- (const char *) src0_extra->data_device [g_main_device], // Query
10970
- (const char *) src1_extra->data_device [g_main_device], // Key
10971
- (const char *) src2_extra->data_device [g_main_device], // Value
10972
- mask ? ((const char *) src3_extra->data_device [g_main_device]) : nullptr , // Mask
10973
- (float *) dst_extra->data_device [g_main_device], // dst
10974
- scale,
10975
- Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
10976
- K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
10977
- mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
10978
- Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
10979
- K->nb [1 ], K->nb [2 ], K->nb [3 ],
10980
- KQV->ne [0 ], KQV->ne [1 ], KQV->ne [2 ], KQV->ne [3 ]
10981
- );
10982
- break ;
10983
- case 128 :
10984
- flash_attn_ext_f16<128 , NQPB, NCPW>
10985
- <<<blocks_num, block_dim, shmem, main_stream>>> (
10986
- (const char *) src0_extra->data_device [g_main_device], // Query
10987
- (const char *) src1_extra->data_device [g_main_device], // Key
10988
- (const char *) src2_extra->data_device [g_main_device], // Value
10989
- mask ? ((const char *) src3_extra->data_device [g_main_device]) : nullptr , // Mask
10990
- (float *) dst_extra->data_device [g_main_device], // dst
10991
- scale,
10992
- Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
10993
- K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
10994
- mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
10995
- Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
10996
- K->nb [1 ], K->nb [2 ], K->nb [3 ],
10997
- KQV->ne [0 ], KQV->ne [1 ], KQV->ne [2 ], KQV->ne [3 ]
10998
- );
10999
- break ;
11000
- default :
11001
- break ;
10931
+ switch (Q->ne [0 ]) {
10932
+ case 64 :
10933
+ flash_attn_ext_f16<64 , NQPB, NCPW>
10934
+ <<<blocks_num, block_dim, shmem, main_stream>>> (
10935
+ (const char *) src0_extra->data_device [g_main_device], // Query
10936
+ (const char *) src1_extra->data_device [g_main_device], // Key
10937
+ (const char *) src2_extra->data_device [g_main_device], // Value
10938
+ mask ? ((const char *) src3_extra->data_device [g_main_device]) : nullptr , // Mask
10939
+ (float *) dst_extra->data_device [g_main_device], // dst
10940
+ scale,
10941
+ Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
10942
+ K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
10943
+ mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
10944
+ Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
10945
+ K->nb [1 ], K->nb [2 ], K->nb [3 ],
10946
+ KQV->ne [0 ], KQV->ne [1 ], KQV->ne [2 ], KQV->ne [3 ]
10947
+ );
10948
+ break ;
10949
+ case 80 :
10950
+ flash_attn_ext_f16<80 , NQPB, NCPW>
10951
+ <<<blocks_num, block_dim, shmem, main_stream>>> (
10952
+ (const char *) src0_extra->data_device [g_main_device], // Query
10953
+ (const char *) src1_extra->data_device [g_main_device], // Key
10954
+ (const char *) src2_extra->data_device [g_main_device], // Value
10955
+ mask ? ((const char *) src3_extra->data_device [g_main_device]) : nullptr , // Mask
10956
+ (float *) dst_extra->data_device [g_main_device], // dst
10957
+ scale,
10958
+ Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
10959
+ K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
10960
+ mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
10961
+ Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
10962
+ K->nb [1 ], K->nb [2 ], K->nb [3 ],
10963
+ KQV->ne [0 ], KQV->ne [1 ], KQV->ne [2 ], KQV->ne [3 ]
10964
+ );
10965
+ break ;
10966
+ case 96 :
10967
+ flash_attn_ext_f16<96 , NQPB, NCPW>
10968
+ <<<blocks_num, block_dim, shmem, main_stream>>> (
10969
+ (const char *) src0_extra->data_device [g_main_device], // Query
10970
+ (const char *) src1_extra->data_device [g_main_device], // Key
10971
+ (const char *) src2_extra->data_device [g_main_device], // Value
10972
+ mask ? ((const char *) src3_extra->data_device [g_main_device]) : nullptr , // Mask
10973
+ (float *) dst_extra->data_device [g_main_device], // dst
10974
+ scale,
10975
+ Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
10976
+ K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
10977
+ mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
10978
+ Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
10979
+ K->nb [1 ], K->nb [2 ], K->nb [3 ],
10980
+ KQV->ne [0 ], KQV->ne [1 ], KQV->ne [2 ], KQV->ne [3 ]
10981
+ );
10982
+ break ;
10983
+ case 112 :
10984
+ flash_attn_ext_f16<112 , NQPB, NCPW>
10985
+ <<<blocks_num, block_dim, shmem, main_stream>>> (
10986
+ (const char *) src0_extra->data_device [g_main_device], // Query
10987
+ (const char *) src1_extra->data_device [g_main_device], // Key
10988
+ (const char *) src2_extra->data_device [g_main_device], // Value
10989
+ mask ? ((const char *) src3_extra->data_device [g_main_device]) : nullptr , // Mask
10990
+ (float *) dst_extra->data_device [g_main_device], // dst
10991
+ scale,
10992
+ Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
10993
+ K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
10994
+ mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
10995
+ Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
10996
+ K->nb [1 ], K->nb [2 ], K->nb [3 ],
10997
+ KQV->ne [0 ], KQV->ne [1 ], KQV->ne [2 ], KQV->ne [3 ]
10998
+ );
10999
+ break ;
11000
+ case 128 :
11001
+ flash_attn_ext_f16<128 , NQPB, NCPW>
11002
+ <<<blocks_num, block_dim, shmem, main_stream>>> (
11003
+ (const char *) src0_extra->data_device [g_main_device], // Query
11004
+ (const char *) src1_extra->data_device [g_main_device], // Key
11005
+ (const char *) src2_extra->data_device [g_main_device], // Value
11006
+ mask ? ((const char *) src3_extra->data_device [g_main_device]) : nullptr , // Mask
11007
+ (float *) dst_extra->data_device [g_main_device], // dst
11008
+ scale,
11009
+ Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
11010
+ K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
11011
+ mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
11012
+ Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
11013
+ K->nb [1 ], K->nb [2 ], K->nb [3 ],
11014
+ KQV->ne [0 ], KQV->ne [1 ], KQV->ne [2 ], KQV->ne [3 ]
11015
+ );
11016
+ break ;
11017
+ case 256 :
11018
+ flash_attn_ext_f16<256 , NQPB, NCPW>
11019
+ <<<blocks_num, block_dim, shmem, main_stream>>> (
11020
+ (const char *) src0_extra->data_device [g_main_device], // Query
11021
+ (const char *) src1_extra->data_device [g_main_device], // Key
11022
+ (const char *) src2_extra->data_device [g_main_device], // Value
11023
+ mask ? ((const char *) src3_extra->data_device [g_main_device]) : nullptr , // Mask
11024
+ (float *) dst_extra->data_device [g_main_device], // dst
11025
+ scale,
11026
+ Q->ne [0 ], Q->ne [1 ], Q->ne [2 ], Q->ne [3 ],
11027
+ K->ne [0 ], K->ne [1 ], K->ne [2 ], K->ne [3 ],
11028
+ mask ? mask->ne [1 ] : 0 , mask ? mask->nb [1 ] : 0 ,
11029
+ Q->nb [1 ], Q->nb [2 ], Q->nb [3 ],
11030
+ K->nb [1 ], K->nb [2 ], K->nb [3 ],
11031
+ KQV->ne [0 ], KQV->ne [1 ], KQV->ne [2 ], KQV->ne [3 ]
11032
+ );
11033
+ break ;
11034
+ default :
11035
+ break ;
11002
11036
}
11003
11037
}
11004
11038
0 commit comments