Skip to content

[SYCL] Provide SYCL 2020 function objects #3868

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jun 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sycl/include/CL/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <CL/sycl/event.hpp>
#include <CL/sycl/exception.hpp>
#include <CL/sycl/feature_test.hpp>
#include <CL/sycl/functional.hpp>
#include <CL/sycl/group.hpp>
#include <CL/sycl/group_algorithm.hpp>
#include <CL/sycl/group_local_memory.hpp>
Expand Down
11 changes: 11 additions & 0 deletions sycl/include/CL/sycl/ONEAPI/functional.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//

#pragma once
#include <CL/sycl/functional.hpp>

#include <functional>

__SYCL_INLINE_NAMESPACE(cl) {
Expand Down Expand Up @@ -90,6 +92,15 @@ struct GroupOpTag<T, detail::enable_if_t<detail::is_sgenfloat<T>::value>> {
return Ret; \
}

// calc overloads for the SYCL 2020 sycl::minimum / sycl::maximum function
// objects. Each invocation pairs a group-op tag (GroupOpISigned,
// GroupOpIUnsigned, GroupOpFP) with an operation name (SMin/UMin/FMin and
// SMax/UMax/FMax) and the function-object type it is selected for.
// NOTE(review): the body of __SYCL_CALC_OVERLOAD is not visible in this hunk;
// presumably the S/U/F prefixes denote the signed, unsigned and floating-point
// variants of the underlying group operation -- confirm against the macro.
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, sycl::minimum<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, sycl::minimum<T>)
__SYCL_CALC_OVERLOAD(GroupOpFP, FMin, sycl::minimum<T>)
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMax, sycl::maximum<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMax, sycl::maximum<T>)
__SYCL_CALC_OVERLOAD(GroupOpFP, FMax, sycl::maximum<T>)

// calc for ONEAPI function objects
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, ONEAPI::minimum<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, ONEAPI::minimum<T>)
__SYCL_CALC_OVERLOAD(GroupOpFP, FMin, ONEAPI::minimum<T>)
Expand Down
57 changes: 57 additions & 0 deletions sycl/include/CL/sycl/functional.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
//==----------- functional.hpp --- SYCL functional -------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once
#include <functional>

__SYCL_INLINE_NAMESPACE(cl) {
namespace sycl {

// Function objects that are plain aliases of their standard-library
// counterparts. Omitting T (the default, void) names the corresponding
// std::...<void> specialization.

/// Alias of std::plus<T>.
template <class T = void> using plus = std::plus<T>;
/// Alias of std::multiplies<T>.
template <class T = void> using multiplies = std::multiplies<T>;
/// Alias of std::bit_or<T>.
template <class T = void> using bit_or = std::bit_or<T>;
/// Alias of std::bit_xor<T>.
template <class T = void> using bit_xor = std::bit_xor<T>;
/// Alias of std::bit_and<T>.
template <class T = void> using bit_and = std::bit_and<T>;

/// Function object that selects the lesser of two values of type T.
template <class T = void> struct minimum {
  /// Returns a copy of \p lhs when std::less<T> orders it before \p rhs;
  /// in every other case (ties included) returns a copy of \p rhs.
  T operator()(const T &lhs, const T &rhs) const {
    if (std::less<T>{}(lhs, rhs))
      return lhs;
    return rhs;
  }
};

/// Specialization for T = void: operand types are deduced at each call.
template <> struct minimum<void> {
  // Empty member tag type named is_transparent.
  struct is_transparent {};

  /// Compares the operands with std::less<> and returns the selected one,
  /// converted to the std::common_type of the two forwarded argument types.
  template <class T, class U>
  auto operator()(T &&lhs, U &&rhs) const
      -> std::common_type_t<T &&, U &&> {
    // Decide first, then forward whichever operand was picked.
    const bool LhsIsLess = std::less<>{}(std::forward<const T>(lhs),
                                         std::forward<const U>(rhs));
    return LhsIsLess ? std::forward<T>(lhs) : std::forward<U>(rhs);
  }
};

/// Function object that selects the greater of two values of type T.
template <class T = void> struct maximum {
  /// Returns a copy of \p lhs when std::greater<T> orders it before \p rhs
  /// (i.e. lhs is the greater one); in every other case (ties included)
  /// returns a copy of \p rhs.
  T operator()(const T &lhs, const T &rhs) const {
    if (std::greater<T>{}(lhs, rhs))
      return lhs;
    return rhs;
  }
};

/// Specialization for T = void: operand types are deduced at each call.
template <> struct maximum<void> {
  // Empty member tag type named is_transparent.
  struct is_transparent {};

  /// Compares the operands with std::greater<> and returns the selected one,
  /// converted to the std::common_type of the two forwarded argument types.
  template <class T, class U>
  auto operator()(T &&lhs, U &&rhs) const
      -> std::common_type_t<T &&, U &&> {
    // Decide first, then forward whichever operand was picked.
    const bool LhsIsGreater = std::greater<>{}(std::forward<const T>(lhs),
                                               std::forward<const U>(rhs));
    return LhsIsGreater ? std::forward<T>(lhs) : std::forward<U>(rhs);
  }
};

} // namespace sycl
} // __SYCL_INLINE_NAMESPACE(cl)
5 changes: 4 additions & 1 deletion sycl/include/CL/sycl/group_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <CL/sycl/ONEAPI/functional.hpp>
#include <CL/sycl/detail/spirv.hpp>
#include <CL/sycl/detail/type_traits.hpp>
#include <CL/sycl/functional.hpp>
#include <CL/sycl/group.hpp>
#include <CL/sycl/known_identity.hpp>
#include <CL/sycl/nd_item.hpp>
Expand Down Expand Up @@ -86,7 +87,9 @@ template <typename T>
using native_op_list =
type_list<ONEAPI::plus<T>, ONEAPI::bit_or<T>, ONEAPI::bit_xor<T>,
ONEAPI::bit_and<T>, ONEAPI::maximum<T>, ONEAPI::minimum<T>,
ONEAPI::multiplies<T>>;
ONEAPI::multiplies<T>, sycl::plus<T>, sycl::bit_or<T>,
sycl::bit_xor<T>, sycl::bit_and<T>, sycl::maximum<T>,
sycl::minimum<T>, sycl::multiplies<T>>;

template <typename T, typename BinaryOperation> struct is_native_op {
static constexpr bool value =
Expand Down
26 changes: 20 additions & 6 deletions sycl/include/CL/sycl/known_identity.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,51 @@ namespace detail {

template <typename T, class BinaryOperation>
using IsPlus =
bool_constant<std::is_same<BinaryOperation, ONEAPI::plus<T>>::value ||
bool_constant<std::is_same<BinaryOperation, sycl::plus<T>>::value ||
std::is_same<BinaryOperation, sycl::plus<void>>::value ||
std::is_same<BinaryOperation, ONEAPI::plus<T>>::value ||
std::is_same<BinaryOperation, ONEAPI::plus<void>>::value>;

template <typename T, class BinaryOperation>
using IsMultiplies = bool_constant<
std::is_same<BinaryOperation, sycl::multiplies<T>>::value ||
std::is_same<BinaryOperation, sycl::multiplies<void>>::value ||
std::is_same<BinaryOperation, ONEAPI::multiplies<T>>::value ||
std::is_same<BinaryOperation, ONEAPI::multiplies<void>>::value>;

template <typename T, class BinaryOperation>
using IsMinimum =
bool_constant<std::is_same<BinaryOperation, ONEAPI::minimum<T>>::value ||
bool_constant<std::is_same<BinaryOperation, sycl::minimum<T>>::value ||
std::is_same<BinaryOperation, sycl::minimum<void>>::value ||
std::is_same<BinaryOperation, ONEAPI::minimum<T>>::value ||
std::is_same<BinaryOperation, ONEAPI::minimum<void>>::value>;

template <typename T, class BinaryOperation>
using IsMaximum =
bool_constant<std::is_same<BinaryOperation, ONEAPI::maximum<T>>::value ||
bool_constant<std::is_same<BinaryOperation, sycl::maximum<T>>::value ||
std::is_same<BinaryOperation, sycl::maximum<void>>::value ||
std::is_same<BinaryOperation, ONEAPI::maximum<T>>::value ||
std::is_same<BinaryOperation, ONEAPI::maximum<void>>::value>;

template <typename T, class BinaryOperation>
using IsBitOR =
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_or<T>>::value ||
bool_constant<std::is_same<BinaryOperation, sycl::bit_or<T>>::value ||
std::is_same<BinaryOperation, sycl::bit_or<void>>::value ||
std::is_same<BinaryOperation, ONEAPI::bit_or<T>>::value ||
std::is_same<BinaryOperation, ONEAPI::bit_or<void>>::value>;

template <typename T, class BinaryOperation>
using IsBitXOR =
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_xor<T>>::value ||
bool_constant<std::is_same<BinaryOperation, sycl::bit_xor<T>>::value ||
std::is_same<BinaryOperation, sycl::bit_xor<void>>::value ||
std::is_same<BinaryOperation, ONEAPI::bit_xor<T>>::value ||
std::is_same<BinaryOperation, ONEAPI::bit_xor<void>>::value>;

template <typename T, class BinaryOperation>
using IsBitAND =
bool_constant<std::is_same<BinaryOperation, ONEAPI::bit_and<T>>::value ||
bool_constant<std::is_same<BinaryOperation, sycl::bit_and<T>>::value ||
std::is_same<BinaryOperation, sycl::bit_and<void>>::value ||
std::is_same<BinaryOperation, ONEAPI::bit_and<T>>::value ||
std::is_same<BinaryOperation, ONEAPI::bit_and<void>>::value>;

// Identity = 0
Expand Down
6 changes: 3 additions & 3 deletions sycl/test/extensions/group-algorithm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ int main() {
std::iota(input.begin(), input.end(), 0);
std::fill(output.begin(), output.end(), 0);

test<class KernelNamePlusV>(q, input, output, plus<>(), 0, GeZero());
test<class KernelNameMinimumV>(q, input, output, minimum<>(),
test<class KernelNamePlusV>(q, input, output, ONEAPI::plus<>(), 0, GeZero());
test<class KernelNameMinimumV>(q, input, output, ONEAPI::minimum<>(),
std::numeric_limits<int>::max(), IsEven());

#ifdef SPIRV_1_3
test<class KernelName_WonwuUVPUPOTKRKIBtT>(q, input, output,
multiplies<int>(), 1, LtZero());
ONEAPI::multiplies<int>(), 1, LtZero());
#endif // SPIRV_1_3

std::cout << "Test passed." << std::endl;
Expand Down
6 changes: 3 additions & 3 deletions sycl/test/on-device/back_to_back_collectives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ int main() {
auto g = it.get_group();
// Loop to increase number of back-to-back calls
for (int r = 0; r < 10; ++r) {
Sum[i] = reduce(g, Input[i], plus<>());
EScan[i] = exclusive_scan(g, Input[i], plus<>());
IScan[i] = inclusive_scan(g, Input[i], plus<>());
Sum[i] = reduce(g, Input[i], sycl::plus<>());
EScan[i] = exclusive_scan(g, Input[i], sycl::plus<>());
IScan[i] = inclusive_scan(g, Input[i], sycl::plus<>());
}
});
});
Expand Down
20 changes: 10 additions & 10 deletions sycl/test/on-device/group_algorithms_sycl2020/exclusive_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,24 +138,24 @@ int main() {
std::iota(input.begin(), input.end(), 0);
std::fill(output.begin(), output.end(), 0);

test<class KernelNamePlusV>(q, input, output, std::plus<>(), 0);
test<class KernelNameMinimumV>(q, input, output, sycl::ONEAPI::minimum<>(),
test<class KernelNamePlusV>(q, input, output, sycl::plus<>(), 0);
test<class KernelNameMinimumV>(q, input, output, sycl::minimum<>(),
std::numeric_limits<int>::max());
test<class KernelNameMaximumV>(q, input, output, sycl::ONEAPI::maximum<>(),
test<class KernelNameMaximumV>(q, input, output, sycl::maximum<>(),
std::numeric_limits<int>::lowest());

test<class KernelNamePlusI>(q, input, output, std::plus<int>(), 0);
test<class KernelNameMinimumI>(q, input, output, sycl::ONEAPI::minimum<int>(),
test<class KernelNamePlusI>(q, input, output, sycl::plus<int>(), 0);
test<class KernelNameMinimumI>(q, input, output, sycl::minimum<int>(),
std::numeric_limits<int>::max());
test<class KernelNameMaximumI>(q, input, output, sycl::ONEAPI::maximum<int>(),
test<class KernelNameMaximumI>(q, input, output, sycl::maximum<int>(),
std::numeric_limits<int>::lowest());

#ifdef SPIRV_1_3
test<class KernelName_VzAPutpBRRJrQPB>(q, input, output, multiplies<int>(),
test<class KernelName_VzAPutpBRRJrQPB>(q, input, output, sycl::multiplies<int>(),
1);
test<class KernelName_UXdGbr>(q, input, output, bit_or<int>(), 0);
test<class KernelName_saYaodNyJknrPW>(q, input, output, bit_xor<int>(), 0);
test<class KernelName_GPcuAlvAOjrDyP>(q, input, output, bit_and<int>(), ~0);
test<class KernelName_UXdGbr>(q, input, output, sycl::bit_or<int>(), 0);
test<class KernelName_saYaodNyJknrPW>(q, input, output, sycl::bit_xor<int>(), 0);
test<class KernelName_GPcuAlvAOjrDyP>(q, input, output, sycl::bit_and<int>(), ~0);
#endif // SPIRV_1_3

std::cout << "Test passed." << std::endl;
Expand Down
20 changes: 10 additions & 10 deletions sycl/test/on-device/group_algorithms_sycl2020/inclusive_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,25 +138,25 @@ int main() {
std::iota(input.begin(), input.end(), 0);
std::fill(output.begin(), output.end(), 0);

test<class KernelNamePlusV>(q, input, output, std::plus<>(), 0);
test<class KernelNameMinimumV>(q, input, output, sycl::ONEAPI::minimum<>(),
test<class KernelNamePlusV>(q, input, output, sycl::plus<>(), 0);
test<class KernelNameMinimumV>(q, input, output, sycl::minimum<>(),
std::numeric_limits<int>::max());
test<class KernelNameMaximumV>(q, input, output, sycl::ONEAPI::maximum<>(),
test<class KernelNameMaximumV>(q, input, output, sycl::maximum<>(),
std::numeric_limits<int>::lowest());

test<class KernelNamePlusI>(q, input, output, std::plus<int>(), 0);
test<class KernelNameMinimumI>(q, input, output, sycl::ONEAPI::minimum<int>(),
test<class KernelNamePlusI>(q, input, output, sycl::plus<int>(), 0);
test<class KernelNameMinimumI>(q, input, output, sycl::minimum<int>(),
std::numeric_limits<int>::max());
test<class KernelNameMaximumI>(q, input, output, sycl::ONEAPI::maximum<int>(),
test<class KernelNameMaximumI>(q, input, output, sycl::maximum<int>(),
std::numeric_limits<int>::lowest());

#ifdef SPIRV_1_3
test<class KernelName_zMyjxUrBgeUGoxmDwhvJ>(q, input, output,
multiplies<int>(), 1);
test<class KernelName_SljjtroxNRaAXoVnT>(q, input, output, bit_or<int>(), 0);
test<class KernelName_yXIZfjwjxQGiPeQAnc>(q, input, output, bit_xor<int>(),
sycl::multiplies<int>(), 1);
test<class KernelName_SljjtroxNRaAXoVnT>(q, input, output, sycl::bit_or<int>(), 0);
test<class KernelName_yXIZfjwjxQGiPeQAnc>(q, input, output, sycl::bit_xor<int>(),
0);
test<class KernelName_xGnAnMYHvqekCk>(q, input, output, bit_and<int>(), ~0);
test<class KernelName_xGnAnMYHvqekCk>(q, input, output, sycl::bit_and<int>(), ~0);
#endif // SPIRV_1_3

std::cout << "Test passed." << std::endl;
Expand Down
20 changes: 10 additions & 10 deletions sycl/test/on-device/group_algorithms_sycl2020/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,24 @@ int main() {
std::iota(input.begin(), input.end(), 0);
std::fill(output.begin(), output.end(), 0);

test<class KernelNamePlusV>(q, input, output, std::plus<>(), 0);
test<class KernelNameMinimumV>(q, input, output, sycl::ONEAPI::minimum<>(),
test<class KernelNamePlusV>(q, input, output, sycl::plus<>(), 0);
test<class KernelNameMinimumV>(q, input, output, sycl::minimum<>(),
std::numeric_limits<int>::max());
test<class KernelNameMaximumV>(q, input, output, sycl::ONEAPI::maximum<>(),
test<class KernelNameMaximumV>(q, input, output, sycl::maximum<>(),
std::numeric_limits<int>::lowest());

test<class KernelNamePlusI>(q, input, output, std::plus<int>(), 0);
test<class KernelNameMinimumI>(q, input, output, sycl::ONEAPI::minimum<int>(),
test<class KernelNamePlusI>(q, input, output, sycl::plus<int>(), 0);
test<class KernelNameMinimumI>(q, input, output, sycl::minimum<int>(),
std::numeric_limits<int>::max());
test<class KernelNameMaximumI>(q, input, output, sycl::ONEAPI::maximum<int>(),
test<class KernelNameMaximumI>(q, input, output, sycl::maximum<int>(),
std::numeric_limits<int>::lowest());

#ifdef SPIRV_1_3
test<class KernelName_WonwuUVPUPOTKRKIBtT>(q, input, output,
multiplies<int>(), 1);
test<class KernelName_qYBaJDZTMGkdIwD>(q, input, output, bit_or<int>(), 0);
test<class KernelName_eLSFt>(q, input, output, bit_xor<int>(), 0);
test<class KernelName_uFhJnxSVhNAiFPTG>(q, input, output, bit_and<int>(), ~0);
sycl::multiplies<int>(), 1);
test<class KernelName_qYBaJDZTMGkdIwD>(q, input, output, sycl::bit_or<int>(), 0);
test<class KernelName_eLSFt>(q, input, output, sycl::bit_xor<int>(), 0);
test<class KernelName_uFhJnxSVhNAiFPTG>(q, input, output, sycl::bit_and<int>(), ~0);
#endif // SPIRV_1_3

std::cout << "Test passed." << std::endl;
Expand Down