8
8
#include < random>
9
9
#include < stdlib.h>
10
10
11
- #define EXPECT_MATRIX_EQ (A, B, R, C ) \
12
- do { \
13
- for (unsigned r = 0 ; r < R; r++) \
14
- for (unsigned c = 0 ; c < C; c++) \
15
- if (A[r + c * R] != B[r + c * R]) { \
16
- std::cerr << " mismatch at " << r << " :" << c << " \n " ; \
17
- exit (1 ); \
18
- } \
19
- } while (false )
11
+ #define ABSTOL 0.000001
12
+ #define RELTOL 0.00001
13
+ bool fpcmp (double V1, double V2, double AbsTolerance, double RelTolerance) {
14
+ // Check to see if these are inside the absolute tolerance
15
+ if (AbsTolerance < fabs (V1 - V2)) {
16
+ // Nope, check the relative tolerance...
17
+ double Diff;
18
+ if (V2)
19
+ Diff = fabs (V1 / V2 - 1.0 );
20
+ else if (V1)
21
+ Diff = fabs (V2 / V1 - 1.0 );
22
+ else
23
+ Diff = 0 ; // Both zero.
24
+ if (Diff > RelTolerance) {
25
+ return true ;
26
+ }
27
+ }
28
+ return false ;
29
+ }
30
+
31
+ template <typename ElementTy, typename std::enable_if_t <
32
+ std::is_integral<ElementTy>::value, int > = 0 >
33
+ void expectMatrixEQ (ElementTy *A, ElementTy *B, unsigned R, unsigned C) {
34
+ do {
35
+ for (unsigned r = 0 ; r < R; r++)
36
+ for (unsigned c = 0 ; c < C; c++)
37
+ if (A[r + c * R] != B[r + c * R]) {
38
+ std::cerr << " mismatch at " << r << " :" << c << " \n " ;
39
+ exit (1 );
40
+ }
41
+ } while (false );
42
+ }
43
+
44
+ template <typename ElementTy,
45
+ typename std::enable_if_t <std::is_floating_point<ElementTy>::value,
46
+ int > = 0 >
47
+ void expectMatrixEQ (ElementTy *A, ElementTy *B, unsigned R,
48
+ unsigned C) {
49
+ do {
50
+ for (unsigned r = 0 ; r < R; r++)
51
+ for (unsigned c = 0 ; c < C; c++)
52
+ if (fpcmp (A[r + c * R], B[r + c * R], ABSTOL, RELTOL)) {
53
+ std::cerr << " mismatch at " << r << " :" << c << " \n " ;
54
+ exit (1 );
55
+ }
56
+ } while (false );
57
+ }
58
+
20
59
21
60
template <typename EltTy>
22
61
void zeroMatrix (EltTy *M, unsigned Rows, unsigned Cols) {
@@ -33,13 +72,25 @@ template <typename EltTy> void print(EltTy *X, unsigned Rows, unsigned Cols) {
33
72
}
34
73
}
35
74
36
- template <typename Ty> void initRandom (Ty *A, unsigned Rows, unsigned Cols) {
75
+ template <typename ElementTy,
76
+ typename std::enable_if_t <std::is_floating_point<ElementTy>::value,
77
+ int > = 0 >
78
+ void initRandom (ElementTy *A, unsigned Rows, unsigned Cols) {
79
+ std::default_random_engine generator;
80
+ std::uniform_real_distribution<ElementTy> distribution (-10.0 , 10.0 );
81
+
82
+ for (unsigned i = 0 ; i < Rows * Cols; i++)
83
+ A[i] = distribution (generator);
84
+ }
85
+
86
+ template <typename ElementTy, typename std::enable_if_t <
87
+ std::is_integral<ElementTy>::value, int > = 0 >
88
+ void initRandom (ElementTy *A, unsigned Rows, unsigned Cols) {
37
89
std::default_random_engine generator;
38
- std::uniform_int_distribution<double > distribution (-10.0 , 10.0 );
39
- auto random_double = std::bind (distribution, generator);
90
+ std::uniform_int_distribution<ElementTy> distribution (-10 , 10 );
40
91
41
92
for (unsigned i = 0 ; i < Rows * Cols; i++)
42
- A[i] = random_double ( );
93
+ A[i] = distribution (generator );
43
94
}
44
95
45
96
template <typename EltTy, unsigned R, unsigned C>
@@ -82,8 +133,8 @@ template <typename EltTy, unsigned R0, unsigned C0> void testTranspose() {
82
133
transposeSpec<EltTy, R0, C0>(ResSpec, X);
83
134
transposeBuiltin<EltTy, R0, C0>(ResBuiltin, X);
84
135
85
- EXPECT_MATRIX_EQ (ResBase, ResBuiltin, R0, C0);
86
- EXPECT_MATRIX_EQ (ResBase, ResSpec, C0, R0);
136
+ expectMatrixEQ (ResBase, ResBuiltin, R0, C0);
137
+ expectMatrixEQ (ResBase, ResSpec, C0, R0);
87
138
}
88
139
89
140
template <typename EltTy, unsigned R0, unsigned C0, unsigned C1>
@@ -150,9 +201,9 @@ void testMultiply() {
150
201
multiplySpec<EltTy, R0, C0, C1>(ResSpec, X, Y);
151
202
multiplyBuiltin<EltTy, R0, C0, C1>(ResBuiltin, X, Y);
152
203
153
- EXPECT_MATRIX_EQ (ResSpec, ResBuiltin, R0, C1);
154
- EXPECT_MATRIX_EQ (ResBase, ResBuiltin, R0, C1);
155
- EXPECT_MATRIX_EQ (ResBase, ResSpec, R0, C1);
204
+ expectMatrixEQ (ResSpec, ResBuiltin, R0, C1);
205
+ expectMatrixEQ (ResBase, ResBuiltin, R0, C1);
206
+ expectMatrixEQ (ResBase, ResSpec, R0, C1);
156
207
}
157
208
158
209
int main (void ) {
0 commit comments