Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Use function object to do equality check in imf utils. #1567

Merged
merged 1 commit into from
Feb 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions SYCL/DeviceLib/imf_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,19 @@ typedef _Float16 _iml_half_internal;
typedef uint16_t _iml_half_internal;
#endif

template <class Ty> class imf_utils_default_equ {
public:
bool operator()(Ty x, Ty y) {
if constexpr (std::is_same_v<Ty, sycl::half2>) {
return (x.s0() == y.s0()) && (x.s1() == y.s1());
} else
return x == y;
};
};

// Used to test half precision utils
template <class InputTy, class OutputTy, class FuncTy>
template <class InputTy, class OutputTy, class FuncTy,
class EquTy = imf_utils_default_equ<OutputTy>>
void test_host(std::initializer_list<InputTy> Input,
std::initializer_list<OutputTy> RefOutput, FuncTy Func,
int Line = __builtin_LINE()) {
Expand All @@ -24,7 +35,7 @@ void test_host(std::initializer_list<InputTy> Input,
for (int i = 0; i < Size; ++i) {
auto Expected = *(std::begin(RefOutput) + i);
auto Res = Func(*(std::begin(Input) + i));
if (Expected == Res)
if (EquTy()(Expected, Res))
continue;

std::cout << "Mismatch at line " << Line << "[" << i << "]: " << Res
Expand All @@ -33,7 +44,8 @@ void test_host(std::initializer_list<InputTy> Input,
}
}

template <class InputTy, class OutputTy, class FuncTy>
template <class InputTy, class OutputTy, class FuncTy,
class EquTy = imf_utils_default_equ<OutputTy>>
void test(sycl::queue &q, std::initializer_list<InputTy> Input,
std::initializer_list<OutputTy> RefOutput, FuncTy Func,
int Line = __builtin_LINE()) {
Expand All @@ -60,15 +72,15 @@ void test(sycl::queue &q, std::initializer_list<InputTy> Input,
for (int i = 0; i < Size; ++i) {
auto Expected = *(std::begin(RefOutput) + i);
if constexpr (std::is_same_v<OutputTy, sycl::half2>) {
if ((Expected.s0() == Acc[i].s0()) && (Expected.s1() == Acc[i].s1()))
if (EquTy()(Expected, Acc[i]))
continue;
std::cout << "Mismatch at line " << Line << "[" << i << "]: ("
<< Acc[i].s0() << ", " << Acc[i].s1() << ")"
<< " != (" << Expected.s0() << ", " << Expected.s1() << ")"
<< ", input was (" << (*(std::begin(Input) + i)).s0() << ", "
<< (*(std::begin(Input) + i)).s1() << ")" << std::endl;
} else {
if (Expected == Acc[i])
if (EquTy()(Expected, Acc[i]))
continue;
std::cout << "Mismatch at line " << Line << "[" << i << "]: " << Acc[i]
<< " != " << Expected << ", input was "
Expand All @@ -78,7 +90,8 @@ void test(sycl::queue &q, std::initializer_list<InputTy> Input,
}
}

template <class InputTy, class OutputTy, class FuncTy>
template <class InputTy, class OutputTy, class FuncTy,
class EquTy = imf_utils_default_equ<OutputTy>>
void test2(sycl::queue &q, std::initializer_list<InputTy> Input1,
std::initializer_list<InputTy> Input2,
std::initializer_list<OutputTy> RefOutput, FuncTy Func,
Expand Down Expand Up @@ -112,14 +125,14 @@ void test2(sycl::queue &q, std::initializer_list<InputTy> Input1,
for (int i = 0; i < Size; ++i) {
auto Expected = *(std::begin(RefOutput) + i);
if constexpr (std::is_same_v<OutputTy, sycl::half2>) {
if ((Expected.s0() == Acc[i].s0()) && (Expected.s1() == Acc[i].s1()))
if (EquTy()(Expected, Acc[i]))
continue;
std::cout << "Mismatch at line " << Line << "[" << i << "]: ("
<< Acc[i].s0() << ", " << Acc[i].s1() << ")"
<< " != (" << Expected.s0() << ", " << Expected.s1()
<< "), input idx was " << i << std::endl;
} else {
if (Expected == Acc[i])
if (EquTy()(Expected, Acc[i]))
continue;
std::cout << "Mismatch at line " << Line << "[" << i << "]: " << Acc[i]
<< " != " << Expected << ", input idx was " << i << std::endl;
Expand All @@ -129,7 +142,7 @@ void test2(sycl::queue &q, std::initializer_list<InputTy> Input1,
}

template <class InputTy1, class InputTy2, class InputTy3, class OutputTy,
class FuncTy>
class FuncTy, class EquTy = imf_utils_default_equ<OutputTy>>
void test3(sycl::queue &q, std::initializer_list<InputTy1> Input1,
std::initializer_list<InputTy2> Input2,
std::initializer_list<InputTy3> Input3,
Expand Down Expand Up @@ -168,14 +181,14 @@ void test3(sycl::queue &q, std::initializer_list<InputTy1> Input1,
for (int i = 0; i < Size; ++i) {
auto Expected = *(std::begin(RefOutput) + i);
if constexpr (std::is_same_v<OutputTy, sycl::half2>) {
if ((Expected.s0() == Acc[i].s0()) && (Expected.s1() == Acc[i].s1()))
if (EquTy()(Expected, Acc[i]))
continue;
std::cout << "Mismatch at line " << Line << "[" << i << "]: ("
<< Acc[i].s0() << ", " << Acc[i].s1() << ")"
<< " != (" << Expected.s0() << ", " << Expected.s1()
<< "), input idx was " << i << std::endl;
} else {
if (Expected == Acc[i])
if (EquTy()(Expected, Acc[i]))
continue;
std::cout << "Mismatch at line " << Line << "[" << i << "]: " << Acc[i]
<< " != " << Expected << ", input was "
Expand Down