Skip to content

Commit af491ee

Browse files
authored
[SYCL][COMPAT] Added compare and unordered compare operations (#12998)
Adds the following ordered and unordered comparisons: - compare - unordered_compare - compare_both (wrapper for sycl::vec<ValueT, 2>) - unordered_compare_both (wrapper for sycl::vec<ValueT, 2>) - compare_mask - unordered_compare_mask
1 parent 4b14d70 commit af491ee

File tree

4 files changed

+541
-0
lines changed

4 files changed

+541
-0
lines changed

sycl/doc/syclcompat/README.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Specifically, this library depends on the following SYCL extensions:
5353
If available, the following extensions extend SYCLcompat functionality:
5454

5555
* [sycl_ext_intel_device_info](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/supported/sycl_ext_intel_device_info.md) \[Optional\]
56+
* [sycl_ext_oneapi_bfloat16_math_functions](../extensions/experimental/sycl_ext_oneapi_bfloat16_math_functions.asciidoc) \[Optional\]
5657

5758
## Usage
5859

@@ -1275,6 +1276,10 @@ static kernel_function_info get_kernel_function_info(const void *function);
12751276
length. `syclcompat::length` provides a templated version that wraps over
12761277
`sycl::length`.
12771278

1279+
`compare`, `unordered_compare`, `compare_both`, `unordered_compare_both`,
1280+
`compare_mask`, and `unordered_compare_mask` handle both ordered and unordered
1281+
comparisons.
1282+
12781283
`vectorized_max` and `vectorized_min` are binary operations returning the
12791284
max/min of two arguments, where each argument is treated as a `sycl::vec` type.
12801285
`vectorized_isgreater` performs elementwise `isgreater`, treating each argument
@@ -1292,6 +1297,45 @@ inline float fast_length(const float *a, int len);
12921297
template <typename ValueT>
12931298
inline ValueT length(const ValueT *a, const int len);
12941299

1300+
// The following definition is enabled when BinaryOperation(ValueT, ValueT) returns bool
1301+
// std::enable_if_t<std::is_same_v<std::invoke_result_t<BinaryOperation, ValueT, ValueT>, bool>, bool>
1302+
template <typename ValueT, class BinaryOperation>
1303+
inline bool
1304+
compare(const ValueT a, const ValueT b, const BinaryOperation binary_op);
1305+
template <typename ValueT, class BinaryOperation>
1306+
inline std::enable_if_t<ValueT::size() == 2, ValueT>
1307+
compare(const ValueT a, const ValueT b, const BinaryOperation binary_op);
1308+
1309+
// The following definition is enabled when BinaryOperation(ValueT, ValueT) returns bool
1310+
// std::enable_if_t<std::is_same_v<std::invoke_result_t<BinaryOperation, ValueT, ValueT>, bool>, bool>
1311+
template <typename ValueT, class BinaryOperation>
1312+
inline bool
1313+
unordered_compare(const ValueT a, const ValueT b,
1314+
const BinaryOperation binary_op);
1315+
template <typename ValueT, class BinaryOperation>
1316+
inline std::enable_if_t<ValueT::size() == 2, ValueT>
1317+
unordered_compare(const ValueT a, const ValueT b,
1318+
const BinaryOperation binary_op);
1319+
1320+
template <typename ValueT, class BinaryOperation>
1321+
inline std::enable_if_t<ValueT::size() == 2, bool>
1322+
compare_both(const ValueT a, const ValueT b, const BinaryOperation binary_op);
1323+
1324+
template <typename ValueT, class BinaryOperation>
1325+
inline std::enable_if_t<ValueT::size() == 2, bool>
1326+
unordered_compare_both(const ValueT a, const ValueT b,
1327+
const BinaryOperation binary_op);
1328+
1329+
template <typename ValueT, class BinaryOperation>
1330+
inline unsigned compare_mask(const sycl::vec<ValueT, 2> a,
1331+
const sycl::vec<ValueT, 2> b,
1332+
const BinaryOperation binary_op);
1333+
1334+
template <typename ValueT, class BinaryOperation>
1335+
inline unsigned unordered_compare_mask(const sycl::vec<ValueT, 2> a,
1336+
const sycl::vec<ValueT, 2> b,
1337+
const BinaryOperation binary_op);
1338+
12951339
template <typename S, typename T> inline T vectorized_max(T a, T b);
12961340

12971341
template <typename S, typename T> inline T vectorized_min(T a, T b);

sycl/include/syclcompat/math.hpp

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,15 @@ inline constexpr RetT extend_binary(AT a, BT b, CT c,
118118
return second_op(extend_temp, extend_c);
119119
}
120120

121+
/// NaN test used by the compare/unordered_compare helpers: forwards to
/// sycl::isnan and returns the outcome as bool.
/// \param [in] a The value to test
/// \returns the result of sycl::isnan(a) converted to bool
template <typename ValueT> inline bool isnan(const ValueT a) {
122+
return sycl::isnan(a);
123+
}
124+
#ifdef SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS
125+
/// bfloat16 overload, compiled only when
/// SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS is defined; forwards to
/// sycl::ext::oneapi::experimental::isnan.
inline bool isnan(const sycl::ext::oneapi::bfloat16 a) {
126+
return sycl::ext::oneapi::experimental::isnan(a);
127+
}
128+
#endif
129+
121130
} // namespace detail
122131

123132
/// Compute fast_length for variable-length array
@@ -167,6 +176,121 @@ inline ValueT length(const ValueT *a, const int len) {
167176
}
168177
}
169178

179+
/// Performs comparison by applying \p binary_op directly to the two operands;
/// no NaN handling is added by this overload.
180+
/// Participates in overload resolution only when invoking \p binary_op on two
/// ValueT arguments yields exactly bool.
/// \param [in] a The first value
181+
/// \param [in] b The second value
182+
/// \param [in] binary_op functor that implements the binary operation
183+
/// \returns the value of binary_op(a, b)
184+
template <typename ValueT, class BinaryOperation>
185+
inline std::enable_if_t<
186+
std::is_same_v<std::invoke_result_t<BinaryOperation, ValueT, ValueT>, bool>,
187+
bool>
188+
compare(const ValueT a, const ValueT b, const BinaryOperation binary_op) {
189+
return binary_op(a, b);
190+
}
191+
/// Overload of compare selected when the functor is std::not_equal_to<>.
/// The result is true only when neither operand is NaN (per detail::isnan) and
/// binary_op(a, b) is true, so any NaN operand makes the result false.
/// NOTE(review): presumably this exists so that "not equal" behaves as an
/// ordered comparison like the other operators -- confirm against callers.
/// \param [in] a The first value
/// \param [in] b The second value
/// \param [in] binary_op the std::not_equal_to<> functor
/// \returns false if either operand is NaN, otherwise binary_op(a, b)
template <typename ValueT>
192+
inline std::enable_if_t<
193+
std::is_same_v<std::invoke_result_t<std::not_equal_to<>, ValueT, ValueT>,
194+
bool>,
195+
bool>
196+
compare(const ValueT a, const ValueT b, const std::not_equal_to<> binary_op) {
197+
return !detail::isnan(a) && !detail::isnan(b) && binary_op(a, b);
198+
}
199+
200+
/// Performs element-wise comparison of two 2-element values.
201+
/// Enabled only when ValueT::size() == 2. Elements 0 and 1 are compared
/// independently through compare() and the two results initialise the
/// returned ValueT.
/// \param [in] a The first value
202+
/// \param [in] b The second value
203+
/// \param [in] binary_op functor that implements the binary operation
204+
/// \returns a ValueT built from the two per-element comparison results
205+
template <typename ValueT, class BinaryOperation>
206+
inline std::enable_if_t<ValueT::size() == 2, ValueT>
207+
compare(const ValueT a, const ValueT b, const BinaryOperation binary_op) {
208+
return {compare(a[0], b[0], binary_op), compare(a[1], b[1], binary_op)};
209+
}
210+
211+
/// Performs unordered comparison: the result is true when either operand is
/// NaN (per detail::isnan); otherwise it is binary_op(a, b).
212+
/// Participates in overload resolution only when invoking \p binary_op on two
/// ValueT arguments yields exactly bool.
/// \param [in] a The first value
213+
/// \param [in] b The second value
214+
/// \param [in] binary_op functor that implements the binary operation
215+
/// \returns true if a or b is NaN, otherwise the value of binary_op(a, b)
216+
template <typename ValueT, class BinaryOperation>
217+
inline std::enable_if_t<
218+
std::is_same_v<std::invoke_result_t<BinaryOperation, ValueT, ValueT>, bool>,
219+
bool>
220+
unordered_compare(const ValueT a, const ValueT b,
221+
const BinaryOperation binary_op) {
222+
return detail::isnan(a) || detail::isnan(b) || binary_op(a, b);
223+
}
224+
225+
/// Performs element-wise unordered comparison of two 2-element values.
226+
/// Enabled only when ValueT::size() == 2. Elements 0 and 1 are compared
/// independently through unordered_compare() and the two results initialise
/// the returned ValueT.
/// \param [in] a The first value
227+
/// \param [in] b The second value
228+
/// \param [in] binary_op functor that implements the binary operation
229+
/// \returns a ValueT built from the two per-element comparison results
230+
template <typename ValueT, class BinaryOperation>
231+
inline std::enable_if_t<ValueT::size() == 2, ValueT>
232+
unordered_compare(const ValueT a, const ValueT b,
233+
const BinaryOperation binary_op) {
234+
return {unordered_compare(a[0], b[0], binary_op),
235+
unordered_compare(a[1], b[1], binary_op)};
236+
}
237+
238+
/// Performs 2-element comparison and returns true only if both per-element
/// results are true.
239+
/// Enabled only when ValueT::size() == 2. The two compare() calls are joined
/// with &&, so element 1 is not compared when element 0 already yields false.
/// \param [in] a The first value
240+
/// \param [in] b The second value
241+
/// \param [in] binary_op functor that implements the binary operation
242+
/// \returns true if the comparison holds for element 0 and for element 1
243+
template <typename ValueT, class BinaryOperation>
244+
inline std::enable_if_t<ValueT::size() == 2, bool>
245+
compare_both(const ValueT a, const ValueT b, const BinaryOperation binary_op) {
246+
return compare(a[0], b[0], binary_op) && compare(a[1], b[1], binary_op);
247+
}
248+
249+
/// Performs 2-element unordered comparison and returns true only if both
250+
/// per-element results are true.
251+
/// Enabled only when ValueT::size() == 2. The two unordered_compare() calls
/// are joined with &&, so element 1 is not compared when element 0 already
/// yields false.
/// \param [in] a The first value
252+
/// \param [in] b The second value
253+
/// \param [in] binary_op functor that implements the binary operation
254+
/// \returns true if the unordered comparison holds for both elements
255+
template <typename ValueT, class BinaryOperation>
256+
inline std::enable_if_t<ValueT::size() == 2, bool>
257+
unordered_compare_both(const ValueT a, const ValueT b,
258+
const BinaryOperation binary_op) {
259+
return unordered_compare(a[0], b[0], binary_op) &&
260+
unordered_compare(a[1], b[1], binary_op);
261+
}
262+
263+
/// Performs 2 elements comparison, compare result of each element is 0 (false)
264+
/// or 0xffff (true), returns an unsigned int by composing compare result of two
265+
/// elements.
266+
/// \param [in] a The first value
267+
/// \param [in] b The second value
268+
/// \param [in] binary_op functor that implements the binary operation
269+
/// \returns the comparison result
270+
template <typename ValueT, class BinaryOperation>
271+
inline unsigned compare_mask(const sycl::vec<ValueT, 2> a,
272+
const sycl::vec<ValueT, 2> b,
273+
const BinaryOperation binary_op) {
274+
// Since compare returns 0 or 1, -compare will be 0x00000000 or 0xFFFFFFFF
275+
return ((-compare(a[0], b[0], binary_op)) << 16) |
276+
((-compare(a[1], b[1], binary_op)) & 0xFFFF);
277+
}
278+
279+
/// Performs 2 elements unordered comparison, compare result of each element is
280+
/// 0 (false) or 0xffff (true), returns an unsigned int by composing compare
281+
/// result of two elements.
282+
/// \param [in] a The first value
283+
/// \param [in] b The second value
284+
/// \param [in] binary_op functor that implements the binary operation
285+
/// \returns the comparison result
286+
template <typename ValueT, class BinaryOperation>
287+
inline unsigned unordered_compare_mask(const sycl::vec<ValueT, 2> a,
288+
const sycl::vec<ValueT, 2> b,
289+
const BinaryOperation binary_op) {
290+
return ((-unordered_compare(a[0], b[0], binary_op)) << 16) |
291+
((-unordered_compare(a[1], b[1], binary_op)) & 0xFFFF);
292+
}
293+
170294
/// Compute vectorized max for two values, with each value treated as a vector
171295
/// type \p S
172296
/// \param [in] S The type of the vector

sycl/test-e2e/syclcompat/common.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,5 @@ void instantiate_all_types(Functor &&f) {
4343
using value_type_list =
4444
std::tuple<int, unsigned int, short, unsigned short, long, unsigned long,
4545
long long, unsigned long long, float, double, sycl::half>;
46+
47+
/// Type list holding only float, double and sycl::half (the floating-point
/// entries that also appear in value_type_list).
using fp_type_list = std::tuple<float, double, sycl::half>;

0 commit comments

Comments
 (0)