|
16 | 16 | #include <chrono>
|
17 | 17 | #include <cstring>
|
18 | 18 | #include <fstream>
|
| 19 | +#include <iomanip> |
19 | 20 | #include <iostream>
|
20 | 21 | #include <iterator>
|
21 | 22 | #include <string>
|
@@ -98,28 +99,88 @@ bool write_binary_file(const char *fname, const std::vector<T> &vec,
|
98 | 99 | return !ofs.bad();
|
99 | 100 | }
|
100 | 101 |
|
101 |
| -template <typename T> |
102 |
| -bool cmp_binary_files(const char *fname1, const char *fname2, T tolerance) { |
103 |
| - const auto vec1 = read_binary_file<T>(fname1); |
104 |
| - const auto vec2 = read_binary_file<T>(fname2); |
105 |
| - if (vec1.size() != vec2.size()) { |
106 |
| - std::cerr << fname1 << " size is " << vec1.size(); |
107 |
| - std::cerr << " whereas " << fname2 << " size is " << vec2.size() |
| 102 | +template <typename T, |
| 103 | + typename std::enable_if<std::is_integral<T>::value || |
| 104 | + std::is_floating_point<T>::value, |
| 105 | + int>::type = 0> |
| 106 | +bool cmp_binary_files(const char *testOutFile, const char *referenceFile, |
| 107 | + const T tolerance = 0, |
| 108 | + const double mismatchRateTolerance = 0, |
| 109 | + const int mismatchReportLimit = 9) { |
| 110 | + |
| 111 | + if (mismatchRateTolerance) { |
| 112 | + if (mismatchRateTolerance >= 1 || mismatchRateTolerance < 0) { |
| 113 | + std::cerr << "Tolerated mismatch rate (" << mismatchRateTolerance |
| 114 | + << ") must be set within [0, 1) range" << std::endl; |
| 115 | + return false; |
| 116 | + } |
| 117 | + |
| 118 | + std::cerr << "Tolerated mismatch rate set to " << mismatchRateTolerance |
108 | 119 | << std::endl;
|
| 120 | + } |
| 121 | + |
| 122 | + const auto testVec = read_binary_file<T>(testOutFile); |
| 123 | + const auto referenceVec = read_binary_file<T>(referenceFile); |
| 124 | + |
| 125 | + if (testVec.size() != referenceVec.size()) { |
| 126 | + std::cerr << testOutFile << " size is " << testVec.size(); |
| 127 | + std::cerr << " whereas " << referenceFile << " size is " |
| 128 | + << referenceVec.size() << std::endl; |
109 | 129 | return false;
|
110 | 130 | }
|
111 |
| - for (size_t i = 0; i < vec1.size(); i++) { |
112 |
| - if (abs(vec1[i] - vec2[i]) > tolerance) { |
113 |
| - std::cerr << "Mismatch at " << i << ' '; |
114 |
| - if (sizeof(T) == 1) { |
115 |
| - std::cerr << (int)vec1[i] << " vs " << (int)vec2[i] << std::endl; |
116 |
| - } else { |
117 |
| - std::cerr << vec1[i] << " vs " << vec2[i] << std::endl; |
| 131 | + |
| 132 | + size_t totalMismatches = 0; |
| 133 | + const size_t size = testVec.size(); |
| 134 | + double maxRelativeDiff = 0; |
| 135 | + bool status = true; |
| 136 | + for (size_t i = 0; i < size; i++) { |
| 137 | + const auto diff = abs(testVec[i] - referenceVec[i]); |
| 138 | + if (diff > tolerance) { |
| 139 | + if (!mismatchRateTolerance || (totalMismatches < mismatchReportLimit)) { |
| 140 | + |
| 141 | + std::cerr << "Mismatch at " << i << ' '; |
| 142 | + if (sizeof(T) == 1) { |
| 143 | + std::cerr << (int)testVec[i] << " vs " << (int)referenceVec[i]; |
| 144 | + } else { |
| 145 | + std::cerr << testVec[i] << " vs " << referenceVec[i]; |
| 146 | + } |
| 147 | + |
| 148 | + maxRelativeDiff = std::max(maxRelativeDiff, |
| 149 | + static_cast<double>(diff) / referenceVec[i]); |
| 150 | + |
| 151 | + if (!mismatchRateTolerance) { |
| 152 | + std::cerr << std::endl; |
| 153 | + status = false; |
| 154 | + break; |
| 155 | + } else { |
| 156 | + std::cerr << ". Current mismatch rate: " << std::setprecision(8) |
| 157 | + << std::fixed << static_cast<double>(totalMismatches) / size |
| 158 | + << std::endl; |
| 159 | + } |
| 160 | + |
| 161 | + } else if (totalMismatches == mismatchReportLimit) { |
| 162 | + std::cerr << "Mismatch output stopped." << std::endl; |
118 | 163 | }
|
119 |
| - return false; |
| 164 | + |
| 165 | + totalMismatches++; |
| 166 | + } |
| 167 | + } |
| 168 | + |
| 169 | + if (totalMismatches) { |
| 170 | + const auto totalMismatchRate = static_cast<double>(totalMismatches) / size; |
| 171 | + if (totalMismatchRate > mismatchRateTolerance) { |
| 172 | + std::cerr << "Mismatch rate of " << totalMismatchRate |
| 173 | + << " has exceeded the tolerated amount of " |
| 174 | + << mismatchRateTolerance << std::endl; |
| 175 | + status = false; |
120 | 176 | }
|
| 177 | + |
| 178 | + std::cerr << "Total mismatch rate is " << totalMismatchRate |
| 179 | + << " with max relative difference of " << maxRelativeDiff |
| 180 | + << std::endl; |
121 | 181 | }
|
122 |
| - return true; |
| 182 | + |
| 183 | + return status; |
123 | 184 | }
|
124 | 185 |
|
125 | 186 | // dump every element of sequence [first, last) to std::cout
|
|
0 commit comments