|
10 | 10 | #include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
|
11 | 11 | #include <iostream>
|
12 | 12 |
|
| 13 | +#include "stats.h" |
13 | 14 | #include "utils.h"
|
14 | 15 |
|
| 16 | +template <uint32_t NTap> |
| 17 | +struct DtJumpFinder { |
| 18 | + private: |
| 19 | + NTapAvgStats<double, NTap> time_avg_; |
| 20 | + AvgStats<double> dtime_avg_; |
| 21 | + double compensation_; |
| 22 | + double threshold_; |
| 23 | + |
| 24 | + public: |
| 25 | + // Compensation is a tiny additive to give on delta time so that the algorithm |
| 26 | + // works smoothly when a sequence of identical timing is ingested, which is |
| 27 | + // pretty common in our tests. Threshold is simply how many times the new |
| 28 | + // delta has to be to be recognized as a deviation. |
| 29 | + DtJumpFinder(double compensation = 0.01, double threshold = 10) |
| 30 | + : time_avg_(), |
| 31 | + dtime_avg_(), |
| 32 | + compensation_(compensation), |
| 33 | + threshold_(threshold) {} |
| 34 | + |
| 35 | + // Returns true if the delta time regarding to the last data point seems |
| 36 | + // normal; returns false if it seems the new data point is too much away from |
| 37 | + // the historical records. |
| 38 | + bool push(double time) { |
| 39 | + if (time_avg_.has_value()) { |
| 40 | + double dtime = std::abs(time - time_avg_) + (compensation_ * time_avg_); |
| 41 | + if (dtime_avg_.has_value()) { |
| 42 | + double ddtime = std::abs(dtime - dtime_avg_); |
| 43 | + std::cout << dtime << "\t" << dtime_avg_ << "\t" << ddtime << "\t"; |
| 44 | + if (ddtime > threshold_ * dtime_avg_) { |
| 45 | + return true; |
| 46 | + } |
| 47 | + } |
| 48 | + dtime_avg_.push(dtime); |
| 49 | + } |
| 50 | + time_avg_.push(time); |
| 51 | + return false; |
| 52 | + } |
| 53 | + |
| 54 | + double dtime_avg() const { |
| 55 | + return dtime_avg_; |
| 56 | + } |
| 57 | + double compensate_time() const { |
| 58 | + return compensation_ * time_avg_; |
| 59 | + } |
| 60 | +}; |
| 61 | + |
15 | 62 | void reg_count() {
|
16 | 63 | const uint32_t NREG_MIN = 1;
|
| 64 | + const uint32_t NREG_MAX = 512; |
| 65 | + const uint32_t NREG_STEP = 1; |
| 66 | + |
| 67 | + const double COMPENSATE = 0.01; |
| 68 | + const double THRESHOLD = 3; |
17 | 69 |
|
18 | 70 | uint32_t NITER;
|
19 | 71 |
|
@@ -43,6 +95,26 @@ void reg_count() {
|
43 | 95 | std::cout << "Calculating NITER..." << std::endl;
|
44 | 96 | ensure_min_niter(1000, NITER, [&]() { return bench(1, 1, NREG_MIN); });
|
45 | 97 | std::cout << "NITER: " << NITER << std::endl;
|
| 98 | + |
| 99 | + uint32_t nreg_max; |
| 100 | + |
| 101 | + DtJumpFinder<5> dj(COMPENSATE, THRESHOLD); |
| 102 | + uint32_t nreg = NREG_MIN; |
| 103 | + for (; nreg <= NREG_MAX; nreg += NREG_STEP) { |
| 104 | + double time = bench(1, 1, nreg); |
| 105 | + std::cout << "Testing nreg=\t" << nreg << "\tTime=\t" << time << std::endl; |
| 106 | + if (dj.push(time)) { |
| 107 | + nreg -= NREG_STEP; |
| 108 | + nreg_max = nreg; |
| 109 | + break; |
| 110 | + } |
| 111 | + } |
| 112 | + if (nreg >= NREG_MAX) { |
| 113 | + std::cout << "Unable to conclude a maximal register count" << std::endl; |
| 114 | + nreg_max = NREG_STEP; |
| 115 | + } else { |
| 116 | + std::cout << nreg_max << " available at most" << std::endl; |
| 117 | + } |
46 | 118 | }
|
47 | 119 |
|
48 | 120 | int main(int argc, const char** argv) {
|
|
0 commit comments