Skip to content

Commit baf7059

Browse files
authored
[SYCL][ESIMD][EMU] tolerated mismatch rate in binary files comparison (intel#885)
* [SYCL][ESIMD][EMU] tolerated mismatch rate in binary files comparison (e.g. for cases resulted from deviations in GPU vs host-based FP computations).
1 parent 280a175 commit baf7059

File tree

1 file changed

+77
-16
lines changed

1 file changed

+77
-16
lines changed

SYCL/ESIMD/esimd_test_utils.hpp

Lines changed: 77 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <chrono>
1717
#include <cstring>
1818
#include <fstream>
19+
#include <iomanip>
1920
#include <iostream>
2021
#include <iterator>
2122
#include <string>
@@ -98,28 +99,88 @@ bool write_binary_file(const char *fname, const std::vector<T> &vec,
9899
return !ofs.bad();
99100
}
100101

101-
template <typename T>
102-
bool cmp_binary_files(const char *fname1, const char *fname2, T tolerance) {
103-
const auto vec1 = read_binary_file<T>(fname1);
104-
const auto vec2 = read_binary_file<T>(fname2);
105-
if (vec1.size() != vec2.size()) {
106-
std::cerr << fname1 << " size is " << vec1.size();
107-
std::cerr << " whereas " << fname2 << " size is " << vec2.size()
102+
template <typename T,
103+
typename std::enable_if<std::is_integral<T>::value ||
104+
std::is_floating_point<T>::value,
105+
int>::type = 0>
106+
bool cmp_binary_files(const char *testOutFile, const char *referenceFile,
107+
const T tolerance = 0,
108+
const double mismatchRateTolerance = 0,
109+
const int mismatchReportLimit = 9) {
110+
111+
if (mismatchRateTolerance) {
112+
if (mismatchRateTolerance >= 1 || mismatchRateTolerance < 0) {
113+
std::cerr << "Tolerated mismatch rate (" << mismatchRateTolerance
114+
<< ") must be set within [0, 1) range" << std::endl;
115+
return false;
116+
}
117+
118+
std::cerr << "Tolerated mismatch rate set to " << mismatchRateTolerance
108119
<< std::endl;
120+
}
121+
122+
const auto testVec = read_binary_file<T>(testOutFile);
123+
const auto referenceVec = read_binary_file<T>(referenceFile);
124+
125+
if (testVec.size() != referenceVec.size()) {
126+
std::cerr << testOutFile << " size is " << testVec.size();
127+
std::cerr << " whereas " << referenceFile << " size is "
128+
<< referenceVec.size() << std::endl;
109129
return false;
110130
}
111-
for (size_t i = 0; i < vec1.size(); i++) {
112-
if (abs(vec1[i] - vec2[i]) > tolerance) {
113-
std::cerr << "Mismatch at " << i << ' ';
114-
if (sizeof(T) == 1) {
115-
std::cerr << (int)vec1[i] << " vs " << (int)vec2[i] << std::endl;
116-
} else {
117-
std::cerr << vec1[i] << " vs " << vec2[i] << std::endl;
131+
132+
size_t totalMismatches = 0;
133+
const size_t size = testVec.size();
134+
double maxRelativeDiff = 0;
135+
bool status = true;
136+
for (size_t i = 0; i < size; i++) {
137+
const auto diff = abs(testVec[i] - referenceVec[i]);
138+
if (diff > tolerance) {
139+
if (!mismatchRateTolerance || (totalMismatches < mismatchReportLimit)) {
140+
141+
std::cerr << "Mismatch at " << i << ' ';
142+
if (sizeof(T) == 1) {
143+
std::cerr << (int)testVec[i] << " vs " << (int)referenceVec[i];
144+
} else {
145+
std::cerr << testVec[i] << " vs " << referenceVec[i];
146+
}
147+
148+
maxRelativeDiff = std::max(maxRelativeDiff,
149+
static_cast<double>(diff) / referenceVec[i]);
150+
151+
if (!mismatchRateTolerance) {
152+
std::cerr << std::endl;
153+
status = false;
154+
break;
155+
} else {
156+
std::cerr << ". Current mismatch rate: " << std::setprecision(8)
157+
<< std::fixed << static_cast<double>(totalMismatches) / size
158+
<< std::endl;
159+
}
160+
161+
} else if (totalMismatches == mismatchReportLimit) {
162+
std::cerr << "Mismatch output stopped." << std::endl;
118163
}
119-
return false;
164+
165+
totalMismatches++;
166+
}
167+
}
168+
169+
if (totalMismatches) {
170+
const auto totalMismatchRate = static_cast<double>(totalMismatches) / size;
171+
if (totalMismatchRate > mismatchRateTolerance) {
172+
std::cerr << "Mismatch rate of " << totalMismatchRate
173+
<< " has exceeded the tolerated amount of "
174+
<< mismatchRateTolerance << std::endl;
175+
status = false;
120176
}
177+
178+
std::cerr << "Total mismatch rate is " << totalMismatchRate
179+
<< " with max relative difference of " << maxRelativeDiff
180+
<< std::endl;
121181
}
122-
return true;
182+
183+
return status;
123184
}
124185

125186
// dump every element of sequence [first, last) to std::cout

0 commit comments

Comments
 (0)