@@ -13,8 +13,19 @@ typedef _Float16 _iml_half_internal;
13
13
typedef uint16_t _iml_half_internal;
14
14
#endif
15
15
16
+ template <class Ty > class imf_utils_default_equ {
17
+ public:
18
+ bool operator ()(Ty x, Ty y) {
19
+ if constexpr (std::is_same_v<Ty, sycl::half2>) {
20
+ return (x.s0 () == y.s0 ()) && (x.s1 () == y.s1 ());
21
+ } else
22
+ return x == y;
23
+ };
24
+ };
25
+
16
26
// Used to test half precision utils
17
- template <class InputTy , class OutputTy , class FuncTy >
27
+ template <class InputTy , class OutputTy , class FuncTy ,
28
+ class EquTy = imf_utils_default_equ<OutputTy>>
18
29
void test_host (std::initializer_list<InputTy> Input,
19
30
std::initializer_list<OutputTy> RefOutput, FuncTy Func,
20
31
int Line = __builtin_LINE()) {
@@ -24,7 +35,7 @@ void test_host(std::initializer_list<InputTy> Input,
24
35
for (int i = 0 ; i < Size; ++i) {
25
36
auto Expected = *(std::begin (RefOutput) + i);
26
37
auto Res = Func (*(std::begin (Input) + i));
27
- if (Expected == Res)
38
+ if (EquTy ()( Expected, Res) )
28
39
continue ;
29
40
30
41
std::cout << " Mismatch at line " << Line << " [" << i << " ]: " << Res
@@ -33,7 +44,8 @@ void test_host(std::initializer_list<InputTy> Input,
33
44
}
34
45
}
35
46
36
- template <class InputTy , class OutputTy , class FuncTy >
47
+ template <class InputTy , class OutputTy , class FuncTy ,
48
+ class EquTy = imf_utils_default_equ<OutputTy>>
37
49
void test (sycl::queue &q, std::initializer_list<InputTy> Input,
38
50
std::initializer_list<OutputTy> RefOutput, FuncTy Func,
39
51
int Line = __builtin_LINE()) {
@@ -60,15 +72,15 @@ void test(sycl::queue &q, std::initializer_list<InputTy> Input,
60
72
for (int i = 0 ; i < Size; ++i) {
61
73
auto Expected = *(std::begin (RefOutput) + i);
62
74
if constexpr (std::is_same_v<OutputTy, sycl::half2>) {
63
- if ((Expected. s0 () == Acc[i]. s0 ()) && ( Expected. s1 () == Acc[i]. s1 () ))
75
+ if (EquTy ()( Expected, Acc[i]))
64
76
continue ;
65
77
std::cout << " Mismatch at line " << Line << " [" << i << " ]: ("
66
78
<< Acc[i].s0 () << " , " << Acc[i].s1 () << " )"
67
79
<< " != (" << Expected.s0 () << " , " << Expected.s1 () << " )"
68
80
<< " , input was (" << (*(std::begin (Input) + i)).s0 () << " , "
69
81
<< (*(std::begin (Input) + i)).s1 () << " )" << std::endl;
70
82
} else {
71
- if (Expected == Acc[i])
83
+ if (EquTy ()( Expected, Acc[i]) )
72
84
continue ;
73
85
std::cout << " Mismatch at line " << Line << " [" << i << " ]: " << Acc[i]
74
86
<< " != " << Expected << " , input was "
@@ -78,7 +90,8 @@ void test(sycl::queue &q, std::initializer_list<InputTy> Input,
78
90
}
79
91
}
80
92
81
- template <class InputTy , class OutputTy , class FuncTy >
93
+ template <class InputTy , class OutputTy , class FuncTy ,
94
+ class EquTy = imf_utils_default_equ<OutputTy>>
82
95
void test2 (sycl::queue &q, std::initializer_list<InputTy> Input1,
83
96
std::initializer_list<InputTy> Input2,
84
97
std::initializer_list<OutputTy> RefOutput, FuncTy Func,
@@ -112,14 +125,14 @@ void test2(sycl::queue &q, std::initializer_list<InputTy> Input1,
112
125
for (int i = 0 ; i < Size; ++i) {
113
126
auto Expected = *(std::begin (RefOutput) + i);
114
127
if constexpr (std::is_same_v<OutputTy, sycl::half2>) {
115
- if ((Expected. s0 () == Acc[i]. s0 ()) && ( Expected. s1 () == Acc[i]. s1 () ))
128
+ if (EquTy ()( Expected, Acc[i]))
116
129
continue ;
117
130
std::cout << " Mismatch at line " << Line << " [" << i << " ]: ("
118
131
<< Acc[i].s0 () << " , " << Acc[i].s1 () << " )"
119
132
<< " != (" << Expected.s0 () << " , " << Expected.s1 ()
120
133
<< " ), input idx was " << i << std::endl;
121
134
} else {
122
- if (Expected == Acc[i])
135
+ if (EquTy ()( Expected, Acc[i]) )
123
136
continue ;
124
137
std::cout << " Mismatch at line " << Line << " [" << i << " ]: " << Acc[i]
125
138
<< " != " << Expected << " , input idx was " << i << std::endl;
@@ -129,7 +142,7 @@ void test2(sycl::queue &q, std::initializer_list<InputTy> Input1,
129
142
}
130
143
131
144
template <class InputTy1 , class InputTy2 , class InputTy3 , class OutputTy ,
132
- class FuncTy >
145
+ class FuncTy , class EquTy = imf_utils_default_equ<OutputTy> >
133
146
void test3 (sycl::queue &q, std::initializer_list<InputTy1> Input1,
134
147
std::initializer_list<InputTy2> Input2,
135
148
std::initializer_list<InputTy3> Input3,
@@ -168,14 +181,14 @@ void test3(sycl::queue &q, std::initializer_list<InputTy1> Input1,
168
181
for (int i = 0 ; i < Size; ++i) {
169
182
auto Expected = *(std::begin (RefOutput) + i);
170
183
if constexpr (std::is_same_v<OutputTy, sycl::half2>) {
171
- if ((Expected. s0 () == Acc[i]. s0 ()) && ( Expected. s1 () == Acc[i]. s1 () ))
184
+ if (EquTy ()( Expected, Acc[i]))
172
185
continue ;
173
186
std::cout << " Mismatch at line " << Line << " [" << i << " ]: ("
174
187
<< Acc[i].s0 () << " , " << Acc[i].s1 () << " )"
175
188
<< " != (" << Expected.s0 () << " , " << Expected.s1 ()
176
189
<< " ), input idx was " << i << std::endl;
177
190
} else {
178
- if (Expected == Acc[i])
191
+ if (EquTy ()( Expected, Acc[i]) )
179
192
continue ;
180
193
std::cout << " Mismatch at line " << Line << " [" << i << " ]: " << Acc[i]
181
194
<< " != " << Expected << " , input was "
0 commit comments