@@ -95,6 +95,22 @@ template <typename ValueT> struct should_skip {
95
95
}
96
96
};
97
97
98
+ #define CHECK (ResultT, RESULT_PTR, EXPECTED ) \
99
+ if constexpr (std::is_integral_v<ResultT>) { \
100
+ assert (*RESULT_PTR == EXPECTED); \
101
+ } else if constexpr (std::is_floating_point_v<ResultT> || \
102
+ std::is_same_v<ResultT, sycl::half>) { \
103
+ if (sycl::isnan (*RESULT_PTR)) \
104
+ assert (sycl::isnan (EXPECTED)); \
105
+ else \
106
+ assert (fabs (*RESULT_PTR - EXPECTED) < ERROR_TOLERANCE); \
107
+ } else if constexpr (contained_is_floating_point_v<ResultT>) { \
108
+ for (size_t i = 0 ; i < RESULT_PTR->size (); i++) \
109
+ assert (fabs ((*RESULT_PTR)[i] - EXPECTED[i]) < ERROR_TOLERANCE); \
110
+ } else { \
111
+ static_assert (0 , " Math_fixt.hpp should not have arrived here." ); \
112
+ }
113
+
98
114
class OpTestLauncher {
99
115
protected:
100
116
syclcompat::dim3 grid_;
@@ -107,14 +123,14 @@ class OpTestLauncher {
107
123
: grid_{grid}, threads_{threads}, skip_{skip} {}
108
124
};
109
125
110
- // Templated TRes to support both arithmetic and boolean operators
126
+ // Templated ResultT to support both arithmetic and boolean operators
111
127
template <typename ValueT, typename ValueU,
112
- typename TRes = std::common_type_t <ValueT, ValueU>>
128
+ typename ResultT = std::common_type_t <ValueT, ValueU>>
113
129
class BinaryOpTestLauncher : OpTestLauncher {
114
130
protected:
115
131
ValueT *op1_;
116
132
ValueU *op2_;
117
- TRes *res_;
133
+ ResultT *res_;
118
134
119
135
public:
120
136
BinaryOpTestLauncher (const syclcompat::dim3 &grid,
@@ -127,7 +143,7 @@ class BinaryOpTestLauncher : OpTestLauncher {
127
143
return ;
128
144
op1_ = syclcompat::malloc_shared<ValueT>(data_size);
129
145
op2_ = syclcompat::malloc_shared<ValueU>(data_size);
130
- res_ = syclcompat::malloc_shared<TRes >(data_size);
146
+ res_ = syclcompat::malloc_shared<ResultT >(data_size);
131
147
};
132
148
133
149
virtual ~BinaryOpTestLauncher () {
@@ -139,28 +155,19 @@ class BinaryOpTestLauncher : OpTestLauncher {
139
155
}
140
156
141
157
template <auto Kernel>
142
- void launch_test (ValueT op1, ValueU op2, TRes expected) {
158
+ void launch_test (ValueT op1, ValueU op2, ResultT expected) {
143
159
if (skip_)
144
160
return ;
145
161
*op1_ = op1;
146
162
*op2_ = op2;
147
163
syclcompat::launch<Kernel>(grid_, threads_, op1_, op2_, res_);
148
164
syclcompat::wait ();
149
165
150
- if constexpr (std::is_integral_v<TRes>)
151
- assert (*res_ == expected);
152
- else if constexpr (std::is_floating_point_v<TRes> ||
153
- std::is_same_v<TRes, sycl::half>)
154
- assert (fabs (*res_ - expected) < ERROR_TOLERANCE);
155
- else if constexpr (contained_is_floating_point_v<TRes>) // Container
156
- for (size_t i = 0 ; i < res_->size (); i++)
157
- assert (fabs ((*res_)[i] - expected[i]) < ERROR_TOLERANCE);
158
- else
159
- assert (0 ); // If arrived here, no results where checked
160
- }
166
+ CHECK (ResultT, res_, expected);
167
+ };
161
168
};
162
169
163
- template <typename ValueT, typename TRes = ValueT>
170
+ template <typename ValueT, typename ResultT = ValueT>
164
171
class UnaryOpTestLauncher : OpTestLauncher {
165
172
protected:
166
173
ValueT *op_;
@@ -186,22 +193,13 @@ class UnaryOpTestLauncher : OpTestLauncher {
186
193
syclcompat::free (res_);
187
194
}
188
195
189
- template <auto Kernel> void launch_test (ValueT op, TRes expected) {
196
+ template <auto Kernel> void launch_test (ValueT op, ResultT expected) {
190
197
if (skip_)
191
198
return ;
192
199
*op_ = op;
193
200
syclcompat::launch<Kernel>(grid_, threads_, op_, res_);
194
201
syclcompat::wait ();
195
202
196
- if constexpr (std::is_integral_v<TRes>)
197
- assert (*res_ == expected);
198
- else if constexpr (std::is_floating_point_v<TRes> ||
199
- std::is_same_v<TRes, sycl::half>)
200
- assert (fabs (*res_ - expected) < ERROR_TOLERANCE);
201
- else if constexpr (contained_is_floating_point_v<TRes>) // Container
202
- for (size_t i = 0 ; i < res_->size (); i++)
203
- assert (fabs ((*res_)[i] - expected[i]) < ERROR_TOLERANCE);
204
- else
205
- assert (0 ); // If arrived here, no results where checked
203
+ CHECK (ResultT, res_, expected);
206
204
}
207
205
};
0 commit comments