Skip to content

Commit 12eaa22

Browse files
committed
tests : update dims
1 parent db1f3c4 commit 12eaa22

File tree

2 files changed

+110
-76
lines changed

2 files changed

+110
-76
lines changed

ggml-cuda.cu

Lines changed: 107 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -6568,7 +6568,8 @@ static __global__ void flash_attn_ext_f16(
65686568
for (int64_t j = 0; j < Q16; ++j) {
65696569
half16x16_a mqka;
65706570
half16x16_acc mm;
6571-
if(mp) {
6571+
6572+
if (mp) {
65726573
nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
65736574
}
65746575

@@ -10927,78 +10928,111 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
1092710928

1092810929
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
1092910930

10930-
switch (Q->ne[0])
10931-
{
10932-
case 16:
10933-
flash_attn_ext_f16<16, NQPB, NCPW>
10934-
<<<blocks_num, block_dim, shmem, main_stream>>> (
10935-
(const char *) src0_extra->data_device[g_main_device], // Query
10936-
(const char *) src1_extra->data_device[g_main_device], // Key
10937-
(const char *) src2_extra->data_device[g_main_device], // Value
10938-
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
10939-
(float *) dst_extra->data_device[g_main_device], // dst
10940-
scale,
10941-
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
10942-
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
10943-
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
10944-
Q->nb[1], Q->nb[2], Q->nb[3],
10945-
K->nb[1], K->nb[2], K->nb[3],
10946-
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
10947-
);
10948-
break;
10949-
case 64:
10950-
flash_attn_ext_f16<64, NQPB, NCPW>
10951-
<<<blocks_num, block_dim, shmem, main_stream>>> (
10952-
(const char *) src0_extra->data_device[g_main_device], // Query
10953-
(const char *) src1_extra->data_device[g_main_device], // Key
10954-
(const char *) src2_extra->data_device[g_main_device], // Value
10955-
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
10956-
(float *) dst_extra->data_device[g_main_device], // dst
10957-
scale,
10958-
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
10959-
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
10960-
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
10961-
Q->nb[1], Q->nb[2], Q->nb[3],
10962-
K->nb[1], K->nb[2], K->nb[3],
10963-
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
10964-
);
10965-
break;
10966-
case 80:
10967-
flash_attn_ext_f16<80, NQPB, NCPW>
10968-
<<<blocks_num, block_dim, shmem, main_stream>>> (
10969-
(const char *) src0_extra->data_device[g_main_device], // Query
10970-
(const char *) src1_extra->data_device[g_main_device], // Key
10971-
(const char *) src2_extra->data_device[g_main_device], // Value
10972-
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
10973-
(float *) dst_extra->data_device[g_main_device], // dst
10974-
scale,
10975-
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
10976-
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
10977-
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
10978-
Q->nb[1], Q->nb[2], Q->nb[3],
10979-
K->nb[1], K->nb[2], K->nb[3],
10980-
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
10981-
);
10982-
break;
10983-
case 128:
10984-
flash_attn_ext_f16<128, NQPB, NCPW>
10985-
<<<blocks_num, block_dim, shmem, main_stream>>> (
10986-
(const char *) src0_extra->data_device[g_main_device], // Query
10987-
(const char *) src1_extra->data_device[g_main_device], // Key
10988-
(const char *) src2_extra->data_device[g_main_device], // Value
10989-
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
10990-
(float *) dst_extra->data_device[g_main_device], // dst
10991-
scale,
10992-
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
10993-
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
10994-
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
10995-
Q->nb[1], Q->nb[2], Q->nb[3],
10996-
K->nb[1], K->nb[2], K->nb[3],
10997-
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
10998-
);
10999-
break;
11000-
default:
11001-
break;
10931+
switch (Q->ne[0]) {
10932+
case 64:
10933+
flash_attn_ext_f16<64, NQPB, NCPW>
10934+
<<<blocks_num, block_dim, shmem, main_stream>>> (
10935+
(const char *) src0_extra->data_device[g_main_device], // Query
10936+
(const char *) src1_extra->data_device[g_main_device], // Key
10937+
(const char *) src2_extra->data_device[g_main_device], // Value
10938+
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
10939+
(float *) dst_extra->data_device[g_main_device], // dst
10940+
scale,
10941+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
10942+
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
10943+
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
10944+
Q->nb[1], Q->nb[2], Q->nb[3],
10945+
K->nb[1], K->nb[2], K->nb[3],
10946+
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
10947+
);
10948+
break;
10949+
case 80:
10950+
flash_attn_ext_f16<80, NQPB, NCPW>
10951+
<<<blocks_num, block_dim, shmem, main_stream>>> (
10952+
(const char *) src0_extra->data_device[g_main_device], // Query
10953+
(const char *) src1_extra->data_device[g_main_device], // Key
10954+
(const char *) src2_extra->data_device[g_main_device], // Value
10955+
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
10956+
(float *) dst_extra->data_device[g_main_device], // dst
10957+
scale,
10958+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
10959+
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
10960+
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
10961+
Q->nb[1], Q->nb[2], Q->nb[3],
10962+
K->nb[1], K->nb[2], K->nb[3],
10963+
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
10964+
);
10965+
break;
10966+
case 96:
10967+
flash_attn_ext_f16<96, NQPB, NCPW>
10968+
<<<blocks_num, block_dim, shmem, main_stream>>> (
10969+
(const char *) src0_extra->data_device[g_main_device], // Query
10970+
(const char *) src1_extra->data_device[g_main_device], // Key
10971+
(const char *) src2_extra->data_device[g_main_device], // Value
10972+
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
10973+
(float *) dst_extra->data_device[g_main_device], // dst
10974+
scale,
10975+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
10976+
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
10977+
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
10978+
Q->nb[1], Q->nb[2], Q->nb[3],
10979+
K->nb[1], K->nb[2], K->nb[3],
10980+
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
10981+
);
10982+
break;
10983+
case 112:
10984+
flash_attn_ext_f16<112, NQPB, NCPW>
10985+
<<<blocks_num, block_dim, shmem, main_stream>>> (
10986+
(const char *) src0_extra->data_device[g_main_device], // Query
10987+
(const char *) src1_extra->data_device[g_main_device], // Key
10988+
(const char *) src2_extra->data_device[g_main_device], // Value
10989+
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
10990+
(float *) dst_extra->data_device[g_main_device], // dst
10991+
scale,
10992+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
10993+
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
10994+
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
10995+
Q->nb[1], Q->nb[2], Q->nb[3],
10996+
K->nb[1], K->nb[2], K->nb[3],
10997+
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
10998+
);
10999+
break;
11000+
case 128:
11001+
flash_attn_ext_f16<128, NQPB, NCPW>
11002+
<<<blocks_num, block_dim, shmem, main_stream>>> (
11003+
(const char *) src0_extra->data_device[g_main_device], // Query
11004+
(const char *) src1_extra->data_device[g_main_device], // Key
11005+
(const char *) src2_extra->data_device[g_main_device], // Value
11006+
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
11007+
(float *) dst_extra->data_device[g_main_device], // dst
11008+
scale,
11009+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
11010+
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
11011+
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
11012+
Q->nb[1], Q->nb[2], Q->nb[3],
11013+
K->nb[1], K->nb[2], K->nb[3],
11014+
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
11015+
);
11016+
break;
11017+
case 256:
11018+
flash_attn_ext_f16<256, NQPB, NCPW>
11019+
<<<blocks_num, block_dim, shmem, main_stream>>> (
11020+
(const char *) src0_extra->data_device[g_main_device], // Query
11021+
(const char *) src1_extra->data_device[g_main_device], // Key
11022+
(const char *) src2_extra->data_device[g_main_device], // Value
11023+
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
11024+
(float *) dst_extra->data_device[g_main_device], // dst
11025+
scale,
11026+
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
11027+
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
11028+
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
11029+
Q->nb[1], Q->nb[2], Q->nb[3],
11030+
K->nb[1], K->nb[2], K->nb[3],
11031+
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
11032+
);
11033+
break;
11034+
default:
11035+
break;
1100211036
}
1100311037
}
1100411038

tests/test-backend-ops.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -572,7 +572,7 @@ struct test_case {
572572
// duplicate the op
573573
size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
574574
int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
575-
#if 1
575+
#if 0
576576
for (int i = 1; i < n_runs; i++) {
577577
gf->nodes[gf->n_nodes++] = out;
578578
}
@@ -2209,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
22092209
test_cases.emplace_back(new test_pad());
22102210
test_cases.emplace_back(new test_leaky_relu());
22112211

2212-
#if 0
2213-
for (int hs : { 64, 80, 96, 112, 128, 256, }) {
2212+
#if 1
2213+
for (int hs : { 64, 80, 128, }) {
22142214
for (int nh : { 32, }) {
22152215
for (int kv : { 512, 1024, 2048, 4096, }) {
22162216
for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {

0 commit comments

Comments
 (0)