#include <thread>
#include <vector>

+// TODO: remove before merging
+// #define TMP_ATTN_BENCH
+
static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
    // static RNG initialization (revisit if n_threads stops being constant)
    static const size_t n_threads = std::thread::hardware_concurrency();
@@ -571,7 +574,7 @@ struct test_case {
        // duplicate the op
        size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
        int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
-#if 0
+#ifndef TMP_ATTN_BENCH
        for (int i = 1; i < n_runs; i++) {
            gf->nodes[gf->n_nodes++] = out;
        }
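
(Not part of the patch: a minimal standalone sketch of the run-count logic gated above, assuming op_size(out) is the number of bytes one evaluation of the op touches; compute_n_runs and the example numbers are illustrative.)

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    // illustrative stand-in for the n_runs computation shown above
    static int compute_n_runs(bool is_cpu, size_t free_graph_nodes, size_t op_bytes) {
        // target roughly 8 GB of memory traffic on CPU backends, 32 GB on GPU backends
        const size_t target_size = is_cpu ? 1ULL << 33 : 1ULL << 35;
        // duplicate the op until the graph is full or the traffic target is reached
        return (int) std::min(free_graph_nodes, target_size / op_bytes) + 1;
    }

    int main() {
        // e.g. a 64 MiB op on a GPU backend with 2048 free graph nodes -> 513 runs
        printf("%d\n", compute_n_runs(false, 2048, 64ULL * 1024 * 1024));
    }
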
@@ -1513,8 +1516,8 @@ struct test_flash_attn_ext : public test_case {
    }
};

+#ifdef TMP_ATTN_BENCH
// ATTN
-// TODO: this is temporary until the FA branch is merged
struct test_attn : public test_case {
    const int64_t hs; // head size
    const int64_t nh; // num heads
@@ -1555,6 +1558,7 @@ struct test_attn : public test_case {
        return cur;
    }
};
+#endif

enum llm_norm_type {
    LLM_NORM,
@@ -2220,7 +2224,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
    test_cases.emplace_back(new test_timestep_embedding());
    test_cases.emplace_back(new test_leaky_relu());

-#if 1
+#ifdef TMP_ATTN_BENCH
    for (int hs : { 128, 256, 64, 80, }) {
        for (int nh : { 32, }) {
            for (int kv : { 512, 1024, 2048, 4096, }) {
@@ -2232,11 +2236,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
        }
    }
#else
-    for (int hs : { 128, }) {
+    for (int hs : { 64, 80, 128, 256, }) {
        for (int nh : { 32, }) {
            for (int kv : { 512, 1024, }) {
-                for (int nb : { 1, 2, 4, 8, 512 }) {
-                    test_cases.emplace_back(new test_attn(hs, nh, kv, nb));
+                for (int nb : { 1, 2, 4, 8, }) {
                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
                }
            }
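
(Not part of the patch: a rough sketch of the default FLASH_ATTN_EXT matrix enumerated by the #else branch above. emit() is a hypothetical stand-in for test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); kv and nb are presumably the KV length and batch size.)

    #include <cstdio>
    #include <initializer_list>

    // hypothetical stand-in for test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb))
    static void emit(int hs, int nh, int kv, int nb) {
        printf("FLASH_ATTN_EXT hs=%d nh=%d kv=%d nb=%d\n", hs, nh, kv, nb);
    }

    int main() {
        // 4 head sizes * 1 head count * 2 KV lengths * 4 batch sizes = 32 cases
        for (int hs : { 64, 80, 128, 256, }) {
            for (int nh : { 32, }) {
                for (int kv : { 512, 1024, }) {
                    for (int nb : { 1, 2, 4, 8, }) {
                        emit(hs, nh, kv, nb);
                    }
                }
            }
        }
    }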