Skip to content

Commit b5376f3

Browse files
jinge90bb-sycl
authored andcommitted
Use function object to do equality check in imf utils. (intel#1567)
Previously, we simply used "==" to compare the result with reference value in imf tests which is not suitable for floating point number. So, replace it with function object to provide more flexibility.
1 parent e16ca1f commit b5376f3

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

SYCL/DeviceLib/imf_utils.hpp

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,19 @@ typedef _Float16 _iml_half_internal;
1313
typedef uint16_t _iml_half_internal;
1414
#endif
1515

16+
template <class Ty> class imf_utils_default_equ {
17+
public:
18+
bool operator()(Ty x, Ty y) {
19+
if constexpr (std::is_same_v<Ty, sycl::half2>) {
20+
return (x.s0() == y.s0()) && (x.s1() == y.s1());
21+
} else
22+
return x == y;
23+
};
24+
};
25+
1626
// Used to test half precision utils
17-
template <class InputTy, class OutputTy, class FuncTy>
27+
template <class InputTy, class OutputTy, class FuncTy,
28+
class EquTy = imf_utils_default_equ<OutputTy>>
1829
void test_host(std::initializer_list<InputTy> Input,
1930
std::initializer_list<OutputTy> RefOutput, FuncTy Func,
2031
int Line = __builtin_LINE()) {
@@ -24,7 +35,7 @@ void test_host(std::initializer_list<InputTy> Input,
2435
for (int i = 0; i < Size; ++i) {
2536
auto Expected = *(std::begin(RefOutput) + i);
2637
auto Res = Func(*(std::begin(Input) + i));
27-
if (Expected == Res)
38+
if (EquTy()(Expected, Res))
2839
continue;
2940

3041
std::cout << "Mismatch at line " << Line << "[" << i << "]: " << Res
@@ -33,7 +44,8 @@ void test_host(std::initializer_list<InputTy> Input,
3344
}
3445
}
3546

36-
template <class InputTy, class OutputTy, class FuncTy>
47+
template <class InputTy, class OutputTy, class FuncTy,
48+
class EquTy = imf_utils_default_equ<OutputTy>>
3749
void test(sycl::queue &q, std::initializer_list<InputTy> Input,
3850
std::initializer_list<OutputTy> RefOutput, FuncTy Func,
3951
int Line = __builtin_LINE()) {
@@ -60,15 +72,15 @@ void test(sycl::queue &q, std::initializer_list<InputTy> Input,
6072
for (int i = 0; i < Size; ++i) {
6173
auto Expected = *(std::begin(RefOutput) + i);
6274
if constexpr (std::is_same_v<OutputTy, sycl::half2>) {
63-
if ((Expected.s0() == Acc[i].s0()) && (Expected.s1() == Acc[i].s1()))
75+
if (EquTy()(Expected, Acc[i]))
6476
continue;
6577
std::cout << "Mismatch at line " << Line << "[" << i << "]: ("
6678
<< Acc[i].s0() << ", " << Acc[i].s1() << ")"
6779
<< " != (" << Expected.s0() << ", " << Expected.s1() << ")"
6880
<< ", input was (" << (*(std::begin(Input) + i)).s0() << ", "
6981
<< (*(std::begin(Input) + i)).s1() << ")" << std::endl;
7082
} else {
71-
if (Expected == Acc[i])
83+
if (EquTy()(Expected, Acc[i]))
7284
continue;
7385
std::cout << "Mismatch at line " << Line << "[" << i << "]: " << Acc[i]
7486
<< " != " << Expected << ", input was "
@@ -78,7 +90,8 @@ void test(sycl::queue &q, std::initializer_list<InputTy> Input,
7890
}
7991
}
8092

81-
template <class InputTy, class OutputTy, class FuncTy>
93+
template <class InputTy, class OutputTy, class FuncTy,
94+
class EquTy = imf_utils_default_equ<OutputTy>>
8295
void test2(sycl::queue &q, std::initializer_list<InputTy> Input1,
8396
std::initializer_list<InputTy> Input2,
8497
std::initializer_list<OutputTy> RefOutput, FuncTy Func,
@@ -112,14 +125,14 @@ void test2(sycl::queue &q, std::initializer_list<InputTy> Input1,
112125
for (int i = 0; i < Size; ++i) {
113126
auto Expected = *(std::begin(RefOutput) + i);
114127
if constexpr (std::is_same_v<OutputTy, sycl::half2>) {
115-
if ((Expected.s0() == Acc[i].s0()) && (Expected.s1() == Acc[i].s1()))
128+
if (EquTy()(Expected, Acc[i]))
116129
continue;
117130
std::cout << "Mismatch at line " << Line << "[" << i << "]: ("
118131
<< Acc[i].s0() << ", " << Acc[i].s1() << ")"
119132
<< " != (" << Expected.s0() << ", " << Expected.s1()
120133
<< "), input idx was " << i << std::endl;
121134
} else {
122-
if (Expected == Acc[i])
135+
if (EquTy()(Expected, Acc[i]))
123136
continue;
124137
std::cout << "Mismatch at line " << Line << "[" << i << "]: " << Acc[i]
125138
<< " != " << Expected << ", input idx was " << i << std::endl;
@@ -129,7 +142,7 @@ void test2(sycl::queue &q, std::initializer_list<InputTy> Input1,
129142
}
130143

131144
template <class InputTy1, class InputTy2, class InputTy3, class OutputTy,
132-
class FuncTy>
145+
class FuncTy, class EquTy = imf_utils_default_equ<OutputTy>>
133146
void test3(sycl::queue &q, std::initializer_list<InputTy1> Input1,
134147
std::initializer_list<InputTy2> Input2,
135148
std::initializer_list<InputTy3> Input3,
@@ -168,14 +181,14 @@ void test3(sycl::queue &q, std::initializer_list<InputTy1> Input1,
168181
for (int i = 0; i < Size; ++i) {
169182
auto Expected = *(std::begin(RefOutput) + i);
170183
if constexpr (std::is_same_v<OutputTy, sycl::half2>) {
171-
if ((Expected.s0() == Acc[i].s0()) && (Expected.s1() == Acc[i].s1()))
184+
if (EquTy()(Expected, Acc[i]))
172185
continue;
173186
std::cout << "Mismatch at line " << Line << "[" << i << "]: ("
174187
<< Acc[i].s0() << ", " << Acc[i].s1() << ")"
175188
<< " != (" << Expected.s0() << ", " << Expected.s1()
176189
<< "), input idx was " << i << std::endl;
177190
} else {
178-
if (Expected == Acc[i])
191+
if (EquTy()(Expected, Acc[i]))
179192
continue;
180193
std::cout << "Mismatch at line " << Line << "[" << i << "]: " << Acc[i]
181194
<< " != " << Expected << ", input was "

0 commit comments

Comments
 (0)