Skip to content

Commit 2006ebf

Browse files
committed
[SYCL][COMPAT] Fixed NaN values causing math tests to fail
1 parent 5d6e53c commit 2006ebf

File tree

1 file changed

+26
-28
lines changed

1 file changed

+26
-28
lines changed

sycl/test-e2e/syclcompat/math/math_fixt.hpp

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,22 @@ template <typename ValueT> struct should_skip {
9595
}
9696
};
9797

98+
#define CHECK(ResultT, RESULT_PTR, EXPECTED) \
99+
if constexpr (std::is_integral_v<ResultT>) { \
100+
assert(*RESULT_PTR == EXPECTED); \
101+
} else if constexpr (std::is_floating_point_v<ResultT> || \
102+
std::is_same_v<ResultT, sycl::half>) { \
103+
if (sycl::isnan(*RESULT_PTR)) \
104+
assert(sycl::isnan(EXPECTED)); \
105+
else \
106+
assert(fabs(*RESULT_PTR - EXPECTED) < ERROR_TOLERANCE); \
107+
} else if constexpr (contained_is_floating_point_v<ResultT>) { \
108+
for (size_t i = 0; i < RESULT_PTR->size(); i++) \
109+
assert(fabs((*RESULT_PTR)[i] - EXPECTED[i]) < ERROR_TOLERANCE); \
110+
} else { \
111+
static_assert(0, "Math_fixt.hpp should not have arrived here."); \
112+
}
113+
98114
class OpTestLauncher {
99115
protected:
100116
syclcompat::dim3 grid_;
@@ -107,14 +123,14 @@ class OpTestLauncher {
107123
: grid_{grid}, threads_{threads}, skip_{skip} {}
108124
};
109125

110-
// Templated TRes to support both arithmetic and boolean operators
126+
// Templated ResultT to support both arithmetic and boolean operators
111127
template <typename ValueT, typename ValueU,
112-
typename TRes = std::common_type_t<ValueT, ValueU>>
128+
typename ResultT = std::common_type_t<ValueT, ValueU>>
113129
class BinaryOpTestLauncher : OpTestLauncher {
114130
protected:
115131
ValueT *op1_;
116132
ValueU *op2_;
117-
TRes *res_;
133+
ResultT *res_;
118134

119135
public:
120136
BinaryOpTestLauncher(const syclcompat::dim3 &grid,
@@ -127,7 +143,7 @@ class BinaryOpTestLauncher : OpTestLauncher {
127143
return;
128144
op1_ = syclcompat::malloc_shared<ValueT>(data_size);
129145
op2_ = syclcompat::malloc_shared<ValueU>(data_size);
130-
res_ = syclcompat::malloc_shared<TRes>(data_size);
146+
res_ = syclcompat::malloc_shared<ResultT>(data_size);
131147
};
132148

133149
virtual ~BinaryOpTestLauncher() {
@@ -139,28 +155,19 @@ class BinaryOpTestLauncher : OpTestLauncher {
139155
}
140156

141157
template <auto Kernel>
142-
void launch_test(ValueT op1, ValueU op2, TRes expected) {
158+
void launch_test(ValueT op1, ValueU op2, ResultT expected) {
143159
if (skip_)
144160
return;
145161
*op1_ = op1;
146162
*op2_ = op2;
147163
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, res_);
148164
syclcompat::wait();
149165

150-
if constexpr (std::is_integral_v<TRes>)
151-
assert(*res_ == expected);
152-
else if constexpr (std::is_floating_point_v<TRes> ||
153-
std::is_same_v<TRes, sycl::half>)
154-
assert(fabs(*res_ - expected) < ERROR_TOLERANCE);
155-
else if constexpr (contained_is_floating_point_v<TRes>) // Container
156-
for (size_t i = 0; i < res_->size(); i++)
157-
assert(fabs((*res_)[i] - expected[i]) < ERROR_TOLERANCE);
158-
else
159-
assert(0); // If arrived here, no results where checked
160-
}
166+
CHECK(ResultT, res_, expected);
167+
};
161168
};
162169

163-
template <typename ValueT, typename TRes = ValueT>
170+
template <typename ValueT, typename ResultT = ValueT>
164171
class UnaryOpTestLauncher : OpTestLauncher {
165172
protected:
166173
ValueT *op_;
@@ -186,22 +193,13 @@ class UnaryOpTestLauncher : OpTestLauncher {
186193
syclcompat::free(res_);
187194
}
188195

189-
template <auto Kernel> void launch_test(ValueT op, TRes expected) {
196+
template <auto Kernel> void launch_test(ValueT op, ResultT expected) {
190197
if (skip_)
191198
return;
192199
*op_ = op;
193200
syclcompat::launch<Kernel>(grid_, threads_, op_, res_);
194201
syclcompat::wait();
195202

196-
if constexpr (std::is_integral_v<TRes>)
197-
assert(*res_ == expected);
198-
else if constexpr (std::is_floating_point_v<TRes> ||
199-
std::is_same_v<TRes, sycl::half>)
200-
assert(fabs(*res_ - expected) < ERROR_TOLERANCE);
201-
else if constexpr (contained_is_floating_point_v<TRes>) // Container
202-
for (size_t i = 0; i < res_->size(); i++)
203-
assert(fabs((*res_)[i] - expected[i]) < ERROR_TOLERANCE);
204-
else
205-
assert(0); // If arrived here, no results where checked
203+
CHECK(ResultT, res_, expected);
206204
}
207205
};

0 commit comments

Comments
 (0)