|
16 | 16 | #include <chrono>
|
17 | 17 | #include <cstring>
|
18 | 18 | #include <fstream>
|
| 19 | +#include <iomanip> |
19 | 20 | #include <iostream>
|
20 | 21 | #include <iterator>
|
21 | 22 | #include <string>
|
@@ -99,26 +100,73 @@ bool write_binary_file(const char *fname, const std::vector<T> &vec,
|
99 | 100 | }
|
100 | 101 |
|
101 | 102 | template <typename T>
|
102 |
| -bool cmp_binary_files(const char *fname1, const char *fname2, T tolerance) { |
| 103 | +bool cmp_binary_files(const char *fname1, const char *fname2, const T tolerance, |
| 104 | + const double toleratedMismatchRate = 0, |
| 105 | + const int mismatchReportThrottleLimit = 5) { |
| 106 | + |
| 107 | + if (toleratedMismatchRate) { |
| 108 | + if (toleratedMismatchRate >= 1 || toleratedMismatchRate < 0) { |
| 109 | + std::cerr << "Tolerated mismatch rate (" << toleratedMismatchRate |
| 110 | + << ") must be set within [0, 1) range" << std::endl; |
| 111 | + return false; |
| 112 | + } |
| 113 | + |
| 114 | + std::cerr << "Tolerated mismatch rate set to " << toleratedMismatchRate |
| 115 | + << std::endl; |
| 116 | + } |
| 117 | + |
103 | 118 | const auto vec1 = read_binary_file<T>(fname1);
|
104 | 119 | const auto vec2 = read_binary_file<T>(fname2);
|
| 120 | + |
105 | 121 | if (vec1.size() != vec2.size()) {
|
106 | 122 | std::cerr << fname1 << " size is " << vec1.size();
|
107 | 123 | std::cerr << " whereas " << fname2 << " size is " << vec2.size()
|
108 | 124 | << std::endl;
|
109 | 125 | return false;
|
110 | 126 | }
|
111 |
| - for (size_t i = 0; i < vec1.size(); i++) { |
| 127 | + |
| 128 | + double totalMismatches = 0; |
| 129 | + const double size = vec1.size(); |
| 130 | + for (size_t i = 0; i < size; i++) { |
112 | 131 | if (abs(vec1[i] - vec2[i]) > tolerance) {
|
113 |
| - std::cerr << "Mismatch at " << i << ' '; |
114 |
| - if (sizeof(T) == 1) { |
115 |
| - std::cerr << (int)vec1[i] << " vs " << (int)vec2[i] << std::endl; |
116 |
| - } else { |
117 |
| - std::cerr << vec1[i] << " vs " << vec2[i] << std::endl; |
| 132 | + if (!toleratedMismatchRate || |
| 133 | + (totalMismatches < mismatchReportThrottleLimit)) { |
| 134 | + |
| 135 | + std::cerr << "Mismatch at " << i << ' '; |
| 136 | + if (sizeof(T) == 1) { |
| 137 | + std::cerr << (int)vec1[i] << " vs " << (int)vec2[i]; |
| 138 | + } else { |
| 139 | + std::cerr << vec1[i] << " vs " << vec2[i]; |
| 140 | + } |
| 141 | + |
| 142 | + if (!toleratedMismatchRate) { |
| 143 | + std::cerr << std::endl; |
| 144 | + return false; |
| 145 | + } else { |
| 146 | + std::cerr << ". Current mismatch rate: " << std::setprecision(8) |
| 147 | + << std::fixed << totalMismatches / size << std::endl; |
| 148 | + } |
| 149 | + |
| 150 | + } else if (totalMismatches == mismatchReportThrottleLimit) { |
| 151 | + std::cerr << "Mismatch output throttled ... " << std::endl; |
118 | 152 | }
|
| 153 | + |
| 154 | + totalMismatches++; |
| 155 | + } |
| 156 | + } |
| 157 | + |
| 158 | + if (totalMismatches) { |
| 159 | + const auto totalMismatchRate = totalMismatches / size; |
| 160 | + if (totalMismatchRate > toleratedMismatchRate) { |
| 161 | + std::cerr << "Mismatch rate of " << totalMismatchRate |
| 162 | + << " has exceeded the tolerated amount of " |
| 163 | + << toleratedMismatchRate << std::endl; |
119 | 164 | return false;
|
120 | 165 | }
|
| 166 | + |
| 167 | + std::cerr << "Total mismatch rate is " << totalMismatchRate << std::endl; |
121 | 168 | }
|
| 169 | + |
122 | 170 | return true;
|
123 | 171 | }
|
124 | 172 |
|
|
0 commit comments