@@ -681,13 +681,15 @@ struct test_case {
         // run
         int64_t total_time_us = 0;
+        int64_t total_mem = 0;
         int total_runs = 0;
         do {
             int64_t start_time = ggml_time_us();
             ggml_backend_graph_compute(backend, gf);
             int64_t end_time = ggml_time_us();
 
             total_time_us += end_time - start_time;
+            total_mem += mem;
             total_runs += n_runs;
         } while (total_time_us < 1000*1000); // run for at least 1 second
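
Note on this hunk: `mem` (computed outside the hunk) appears to hold the bytes moved by one iteration of the timing loop, so accumulating it into `total_mem` keeps the numerator consistent with the `total_time_us` denominator. A minimal sketch of the metric the harness reports (the helper name is mine, not part of the patch):

    // Bytes accumulated over all timed iterations, divided by elapsed
    // seconds and scaled to GB/s, mirroring the printf in the next hunk.
    static double bandwidth_gb_s(int64_t total_mem, int64_t total_time_us) {
        return total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0;
    }
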
@@ -717,7 +719,7 @@ struct test_case {
         } else {
             printf("  %8zu kB/run - \033[1;34m%7.2f GB/s\033[0m",
                 op_size(out) / 1024,
-                mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0);
+                total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0);
         }
         printf("\n");
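
A quick sanity check on the corrected expression (illustrative numbers only): if the loop ran 4 iterations that each moved 1 GiB, `total_mem` is 4 GiB; with `total_time_us` = 2,000,000 the harness prints 4 / 2 = 2.00 GB/s. The old expression divided the single-iteration `mem` by the full elapsed time, understating bandwidth by the iteration count whenever the loop repeated to fill the 1-second window.
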
@@ -2740,6 +2742,13 @@ struct test_flash_attn_ext : public test_case {
         return 5e-4;
     }
 
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        // Just counting matmul costs:
+        // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
+        return 2 * 2 * nh * nb * hs * kv;
+    }
+
     test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
                         bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
         : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
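
For scale, a worked instance of the new `op_flops()` with the constructor defaults from this hunk (illustrative arithmetic, not part of the patch):

    // Per head, Q*K^T fills an nb x kv block with hs MACs per element and
    // P*V an nb x hs block with kv MACs per element: nb*hs*kv MACs each.
    // 2 matmuls * 2 FLOPs per MAC, times nh heads:
    constexpr int64_t hs = 128, nh = 32, kv = 96, nb = 8;
    static_assert(2 * 2 * nh * nb * hs * kv == 12582912, "~12.6 MFLOP per run");
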
@@ -3779,6 +3788,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
 
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
+
     for (int bs : {1, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
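
On the new perf case: a 512 x 3072 F32 -> F16 copy is purely bandwidth bound, making it a useful probe for the corrected GB/s figure. Counting source reads plus destination writes, one op instance moves 512*3072*4 + 512*3072*2 = 9,437,184 bytes (9 MiB); the exact number the harness reports depends on how `op_size()` accounts for the tensors involved.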