@@ -118,6 +118,15 @@ inline constexpr RetT extend_binary(AT a, BT b, CT c,
118
118
return second_op (extend_temp, extend_c);
119
119
}
120
120
121
+ template <typename ValueT> inline bool isnan (const ValueT a) {
122
+ return sycl::isnan (a);
123
+ }
124
+ #ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
125
+ inline bool isnan (const sycl::ext::oneapi::bfloat16 a) {
126
+ return sycl::ext::oneapi::experimental::isnan (a);
127
+ }
128
+ #endif
129
+
121
130
} // namespace detail
122
131
123
132
// / Compute fast_length for variable-length array
@@ -167,6 +176,121 @@ inline ValueT length(const ValueT *a, const int len) {
167
176
}
168
177
}
169
178
179
+ // / Performs comparison.
180
+ // / \param [in] a The first value
181
+ // / \param [in] b The second value
182
+ // / \param [in] binary_op functor that implements the binary operation
183
+ // / \returns the comparison result
184
+ template <typename ValueT, class BinaryOperation >
185
+ inline std::enable_if_t <
186
+ std::is_same_v<std::invoke_result_t <BinaryOperation, ValueT, ValueT>, bool >,
187
+ bool >
188
+ compare (const ValueT a, const ValueT b, const BinaryOperation binary_op) {
189
+ return binary_op (a, b);
190
+ }
191
+ template <typename ValueT>
192
+ inline std::enable_if_t <
193
+ std::is_same_v<std::invoke_result_t <std::not_equal_to<>, ValueT, ValueT>,
194
+ bool >,
195
+ bool >
196
+ compare (const ValueT a, const ValueT b, const std::not_equal_to<> binary_op) {
197
+ return !detail::isnan (a) && !detail::isnan (b) && binary_op (a, b);
198
+ }
199
+
200
+ // / Performs 2 element comparison.
201
+ // / \param [in] a The first value
202
+ // / \param [in] b The second value
203
+ // / \param [in] binary_op functor that implements the binary operation
204
+ // / \returns the comparison result
205
+ template <typename ValueT, class BinaryOperation >
206
+ inline std::enable_if_t <ValueT::size() == 2 , ValueT>
207
+ compare (const ValueT a, const ValueT b, const BinaryOperation binary_op) {
208
+ return {compare (a[0 ], b[0 ], binary_op), compare (a[1 ], b[1 ], binary_op)};
209
+ }
210
+
211
+ // / Performs unordered comparison.
212
+ // / \param [in] a The first value
213
+ // / \param [in] b The second value
214
+ // / \param [in] binary_op functor that implements the binary operation
215
+ // / \returns the comparison result
216
+ template <typename ValueT, class BinaryOperation >
217
+ inline std::enable_if_t <
218
+ std::is_same_v<std::invoke_result_t <BinaryOperation, ValueT, ValueT>, bool >,
219
+ bool >
220
+ unordered_compare (const ValueT a, const ValueT b,
221
+ const BinaryOperation binary_op) {
222
+ return detail::isnan (a) || detail::isnan (b) || binary_op (a, b);
223
+ }
224
+
225
+ // / Performs 2 element unordered comparison.
226
+ // / \param [in] a The first value
227
+ // / \param [in] b The second value
228
+ // / \param [in] binary_op functor that implements the binary operation
229
+ // / \returns the comparison result
230
+ template <typename ValueT, class BinaryOperation >
231
+ inline std::enable_if_t <ValueT::size() == 2 , ValueT>
232
+ unordered_compare (const ValueT a, const ValueT b,
233
+ const BinaryOperation binary_op) {
234
+ return {unordered_compare (a[0 ], b[0 ], binary_op),
235
+ unordered_compare (a[1 ], b[1 ], binary_op)};
236
+ }
237
+
238
+ // / Performs 2 element comparison and return true if both results are true.
239
+ // / \param [in] a The first value
240
+ // / \param [in] b The second value
241
+ // / \param [in] binary_op functor that implements the binary operation
242
+ // / \returns the comparison result
243
+ template <typename ValueT, class BinaryOperation >
244
+ inline std::enable_if_t <ValueT::size() == 2 , bool >
245
+ compare_both (const ValueT a, const ValueT b, const BinaryOperation binary_op) {
246
+ return compare (a[0 ], b[0 ], binary_op) && compare (a[1 ], b[1 ], binary_op);
247
+ }
248
+
249
+ // / Performs 2 element unordered comparison and return true if both results are
250
+ // / true.
251
+ // / \param [in] a The first value
252
+ // / \param [in] b The second value
253
+ // / \param [in] binary_op functor that implements the binary operation
254
+ // / \returns the comparison result
255
+ template <typename ValueT, class BinaryOperation >
256
+ inline std::enable_if_t <ValueT::size() == 2 , bool >
257
+ unordered_compare_both (const ValueT a, const ValueT b,
258
+ const BinaryOperation binary_op) {
259
+ return unordered_compare (a[0 ], b[0 ], binary_op) &&
260
+ unordered_compare (a[1 ], b[1 ], binary_op);
261
+ }
262
+
263
+ // / Performs 2 elements comparison, compare result of each element is 0 (false)
264
+ // / or 0xffff (true), returns an unsigned int by composing compare result of two
265
+ // / elements.
266
+ // / \param [in] a The first value
267
+ // / \param [in] b The second value
268
+ // / \param [in] binary_op functor that implements the binary operation
269
+ // / \returns the comparison result
270
+ template <typename ValueT, class BinaryOperation >
271
+ inline unsigned compare_mask (const sycl::vec<ValueT, 2 > a,
272
+ const sycl::vec<ValueT, 2 > b,
273
+ const BinaryOperation binary_op) {
274
+ // Since compare returns 0 or 1, -compare will be 0x00000000 or 0xFFFFFFFF
275
+ return ((-compare (a[0 ], b[0 ], binary_op)) << 16 ) |
276
+ ((-compare (a[1 ], b[1 ], binary_op)) & 0xFFFF );
277
+ }
278
+
279
+ // / Performs 2 elements unordered comparison, compare result of each element is
280
+ // / 0 (false) or 0xffff (true), returns an unsigned int by composing compare
281
+ // / result of two elements.
282
+ // / \param [in] a The first value
283
+ // / \param [in] b The second value
284
+ // / \param [in] binary_op functor that implements the binary operation
285
+ // / \returns the comparison result
286
+ template <typename ValueT, class BinaryOperation >
287
+ inline unsigned unordered_compare_mask (const sycl::vec<ValueT, 2 > a,
288
+ const sycl::vec<ValueT, 2 > b,
289
+ const BinaryOperation binary_op) {
290
+ return ((-unordered_compare (a[0 ], b[0 ], binary_op)) << 16 ) |
291
+ ((-unordered_compare (a[1 ], b[1 ], binary_op)) & 0xFFFF );
292
+ }
293
+
170
294
// / Compute vectorized max for two values, with each value treated as a vector
171
295
// / type \p S
172
296
// / \param [in] S The type of the vector
0 commit comments