Skip to content

[SYCL][NFC] Fix duplication caused by functors defined in ext::oneapi… #4361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions sycl/include/CL/sycl/group_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,9 @@ get_local_linear_id<ext::oneapi::sub_group>(ext::oneapi::sub_group g) {
// ---- is_native_op
template <typename T>
using native_op_list =
type_list<ext::oneapi::plus<T>, ext::oneapi::bit_or<T>,
ext::oneapi::bit_xor<T>, ext::oneapi::bit_and<T>,
ext::oneapi::maximum<T>, ext::oneapi::minimum<T>,
ext::oneapi::multiplies<T>, sycl::plus<T>, sycl::bit_or<T>,
sycl::bit_xor<T>, sycl::bit_and<T>, sycl::maximum<T>,
sycl::minimum<T>, sycl::multiplies<T>>;
type_list<sycl::plus<T>, sycl::bit_or<T>, sycl::bit_xor<T>,
sycl::bit_and<T>, sycl::maximum<T>, sycl::minimum<T>,
sycl::multiplies<T>>;

template <typename T, typename BinaryOperation> struct is_native_op {
static constexpr bool value =
Expand Down
56 changes: 21 additions & 35 deletions sycl/include/CL/sycl/known_identity.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,53 +18,39 @@ namespace sycl {
namespace detail {

template <typename T, class BinaryOperation>
using IsPlus = bool_constant<
std::is_same<BinaryOperation, sycl::plus<T>>::value ||
std::is_same<BinaryOperation, sycl::plus<void>>::value ||
std::is_same<BinaryOperation, ext::oneapi::plus<T>>::value ||
std::is_same<BinaryOperation, ext::oneapi::plus<void>>::value>;
using IsPlus =
bool_constant<std::is_same<BinaryOperation, sycl::plus<T>>::value ||
std::is_same<BinaryOperation, sycl::plus<void>>::value>;

template <typename T, class BinaryOperation>
using IsMultiplies = bool_constant<
std::is_same<BinaryOperation, sycl::multiplies<T>>::value ||
std::is_same<BinaryOperation, sycl::multiplies<void>>::value ||
std::is_same<BinaryOperation, ext::oneapi::multiplies<T>>::value ||
std::is_same<BinaryOperation, ext::oneapi::multiplies<void>>::value>;
using IsMultiplies =
bool_constant<std::is_same<BinaryOperation, sycl::multiplies<T>>::value ||
std::is_same<BinaryOperation, sycl::multiplies<void>>::value>;

template <typename T, class BinaryOperation>
using IsMinimum = bool_constant<
std::is_same<BinaryOperation, sycl::minimum<T>>::value ||
std::is_same<BinaryOperation, sycl::minimum<void>>::value ||
std::is_same<BinaryOperation, ext::oneapi::minimum<T>>::value ||
std::is_same<BinaryOperation, ext::oneapi::minimum<void>>::value>;
using IsMinimum =
bool_constant<std::is_same<BinaryOperation, sycl::minimum<T>>::value ||
std::is_same<BinaryOperation, sycl::minimum<void>>::value>;

template <typename T, class BinaryOperation>
using IsMaximum = bool_constant<
std::is_same<BinaryOperation, sycl::maximum<T>>::value ||
std::is_same<BinaryOperation, sycl::maximum<void>>::value ||
std::is_same<BinaryOperation, ext::oneapi::maximum<T>>::value ||
std::is_same<BinaryOperation, ext::oneapi::maximum<void>>::value>;
using IsMaximum =
bool_constant<std::is_same<BinaryOperation, sycl::maximum<T>>::value ||
std::is_same<BinaryOperation, sycl::maximum<void>>::value>;

template <typename T, class BinaryOperation>
using IsBitOR = bool_constant<
std::is_same<BinaryOperation, sycl::bit_or<T>>::value ||
std::is_same<BinaryOperation, sycl::bit_or<void>>::value ||
std::is_same<BinaryOperation, ext::oneapi::bit_or<T>>::value ||
std::is_same<BinaryOperation, ext::oneapi::bit_or<void>>::value>;
using IsBitOR =
bool_constant<std::is_same<BinaryOperation, sycl::bit_or<T>>::value ||
std::is_same<BinaryOperation, sycl::bit_or<void>>::value>;

template <typename T, class BinaryOperation>
using IsBitXOR = bool_constant<
std::is_same<BinaryOperation, sycl::bit_xor<T>>::value ||
std::is_same<BinaryOperation, sycl::bit_xor<void>>::value ||
std::is_same<BinaryOperation, ext::oneapi::bit_xor<T>>::value ||
std::is_same<BinaryOperation, ext::oneapi::bit_xor<void>>::value>;
using IsBitXOR =
bool_constant<std::is_same<BinaryOperation, sycl::bit_xor<T>>::value ||
std::is_same<BinaryOperation, sycl::bit_xor<void>>::value>;

template <typename T, class BinaryOperation>
using IsBitAND = bool_constant<
std::is_same<BinaryOperation, sycl::bit_and<T>>::value ||
std::is_same<BinaryOperation, sycl::bit_and<void>>::value ||
std::is_same<BinaryOperation, ext::oneapi::bit_and<T>>::value ||
std::is_same<BinaryOperation, ext::oneapi::bit_and<void>>::value>;
using IsBitAND =
bool_constant<std::is_same<BinaryOperation, sycl::bit_and<T>>::value ||
std::is_same<BinaryOperation, sycl::bit_and<void>>::value>;

// Identity = 0
template <typename T, class BinaryOperation>
Expand Down
81 changes: 18 additions & 63 deletions sycl/include/sycl/ext/oneapi/functional.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,13 @@ namespace sycl {
namespace ext {
namespace oneapi {

template <typename T = void> struct minimum {
T operator()(const T &lhs, const T &rhs) const {
return std::less<T>()(lhs, rhs) ? lhs : rhs;
}
};

template <> struct minimum<void> {
struct is_transparent {};
template <typename T, typename U>
auto operator()(T &&lhs, U &&rhs) const ->
typename std::common_type<T &&, U &&>::type {
return std::less<>()(std::forward<const T>(lhs), std::forward<const U>(rhs))
? std::forward<T>(lhs)
: std::forward<U>(rhs);
}
};

template <typename T = void> struct maximum {
T operator()(const T &lhs, const T &rhs) const {
return std::greater<T>()(lhs, rhs) ? lhs : rhs;
}
};

template <> struct maximum<void> {
struct is_transparent {};
template <typename T, typename U>
auto operator()(T &&lhs, U &&rhs) const ->
typename std::common_type<T &&, U &&>::type {
return std::greater<>()(std::forward<const T>(lhs),
std::forward<const U>(rhs))
? std::forward<T>(lhs)
: std::forward<U>(rhs);
}
};

template <typename T = void> using plus = std::plus<T>;
template <typename T = void> using multiplies = std::multiplies<T>;
template <typename T = void> using bit_or = std::bit_or<T>;
template <typename T = void> using bit_xor = std::bit_xor<T>;
template <typename T = void> using bit_and = std::bit_and<T>;
template <typename T = void> using maximum = sycl::maximum<T>;
template <typename T = void> using minimum = sycl::minimum<T>;

} // namespace oneapi
} // namespace ext
Expand Down Expand Up @@ -106,41 +73,29 @@ struct GroupOpTag<T, detail::enable_if_t<detail::is_sgenfloat<T>::value>> {
return Ret; \
}

// calc for sycl minimum/maximum function objects
// calc for sycl function objects
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, sycl::minimum<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, sycl::minimum<T>)
__SYCL_CALC_OVERLOAD(GroupOpFP, FMin, sycl::minimum<T>)

__SYCL_CALC_OVERLOAD(GroupOpISigned, SMax, sycl::maximum<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMax, sycl::maximum<T>)
__SYCL_CALC_OVERLOAD(GroupOpFP, FMax, sycl::maximum<T>)

// calc for oneapi function objects
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMin, ext::oneapi::minimum<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMin, ext::oneapi::minimum<T>)
__SYCL_CALC_OVERLOAD(GroupOpFP, FMin, ext::oneapi::minimum<T>)
__SYCL_CALC_OVERLOAD(GroupOpISigned, SMax, ext::oneapi::maximum<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, UMax, ext::oneapi::maximum<T>)
__SYCL_CALC_OVERLOAD(GroupOpFP, FMax, ext::oneapi::maximum<T>)
__SYCL_CALC_OVERLOAD(GroupOpISigned, IAdd, ext::oneapi::plus<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, IAdd, ext::oneapi::plus<T>)
__SYCL_CALC_OVERLOAD(GroupOpFP, FAdd, ext::oneapi::plus<T>)

__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformIMul, ext::oneapi::multiplies<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformIMul,
ext::oneapi::multiplies<T>)
__SYCL_CALC_OVERLOAD(GroupOpFP, NonUniformFMul, ext::oneapi::multiplies<T>)
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseOr,
ext::oneapi::bit_or<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseOr,
ext::oneapi::bit_or<T>)
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseXor,
ext::oneapi::bit_xor<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseXor,
ext::oneapi::bit_xor<T>)
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseAnd,
ext::oneapi::bit_and<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseAnd,
ext::oneapi::bit_and<T>)
__SYCL_CALC_OVERLOAD(GroupOpISigned, IAdd, sycl::plus<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, IAdd, sycl::plus<T>)
__SYCL_CALC_OVERLOAD(GroupOpFP, FAdd, sycl::plus<T>)

__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformIMul, sycl::multiplies<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformIMul, sycl::multiplies<T>)
__SYCL_CALC_OVERLOAD(GroupOpFP, NonUniformFMul, sycl::multiplies<T>)

__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseOr, sycl::bit_or<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseOr, sycl::bit_or<T>)
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseXor, sycl::bit_xor<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseXor, sycl::bit_xor<T>)
__SYCL_CALC_OVERLOAD(GroupOpISigned, NonUniformBitwiseAnd, sycl::bit_and<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, NonUniformBitwiseAnd, sycl::bit_and<T>)

#undef __SYCL_CALC_OVERLOAD

Expand Down
8 changes: 3 additions & 5 deletions sycl/include/sycl/ext/oneapi/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class reducer {
/// using those operations, which are based on functionality provided by
/// sycl::atomic class.
///
/// For example, it is known that 0 is identity for ext::oneapi::plus operations
/// For example, it is known that 0 is identity for sycl::plus operations
/// accepting native scalar types to which scalar 0 is convertible.
/// Also, for int32/64 types the atomic_combine() is lowered to
/// sycl::atomic::fetch_add().
Expand Down Expand Up @@ -317,8 +317,7 @@ class reducer<T, BinaryOperation,
.fetch_and(MValue);
}

/// Atomic MIN operation: *ReduVarPtr = ext::oneapi::minimum(*ReduVarPtr,
/// MValue);
/// Atomic MIN operation: *ReduVarPtr = sycl::minimum(*ReduVarPtr, MValue);
template <access::address_space Space = access::address_space::global_space,
typename _T = T, class _BinaryOperation = BinaryOperation>
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
Expand All @@ -332,8 +331,7 @@ class reducer<T, BinaryOperation,
.fetch_min(MValue);
}

/// Atomic MAX operation: *ReduVarPtr = ext::oneapi::maximum(*ReduVarPtr,
/// MValue);
/// Atomic MAX operation: *ReduVarPtr = sycl::maximum(*ReduVarPtr, MValue);
template <access::address_space Space = access::address_space::global_space,
typename _T = T, class _BinaryOperation = BinaryOperation>
enable_if_t<std::is_same<typename remove_AS<_T>::type, T>::value &&
Expand Down
7 changes: 4 additions & 3 deletions sycl/source/detail/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ __SYCL_EXPORT uint32_t reduGetMaxNumConcurrentWorkGroups(
std::shared_ptr<sycl::detail::queue_impl> Queue) {
device Dev = Queue->get_device();
uint32_t NumThreads = Dev.get_info<info::device::max_compute_units>();
// The heuristics require additional tuning for various devices and vendors.
// For now assuming that each of execution units have about 8 working threads
// gives good results on some known/supported GPU devices.
// TODO: The heuristics here require additional tuning for various devices
// and vendors. For now this code assumes that execution units have about
// 8 working threads, which gives good results on some known/supported
// GPU devices.
if (Dev.is_gpu())
NumThreads *= 8;
return NumThreads;
Expand Down
10 changes: 5 additions & 5 deletions sycl/test/basic_tests/reduction_known_identity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
using namespace cl::sycl;

template <typename T> void checkCommonBasicKnownIdentity() {
static_assert(has_known_identity<ext::oneapi::maximum<>, T>::value);
static_assert(has_known_identity<ext::oneapi::maximum<T>, T>::value);
static_assert(has_known_identity<ext::oneapi::minimum<>, T>::value);
static_assert(has_known_identity<ext::oneapi::minimum<T>, T>::value);
static_assert(has_known_identity<sycl::maximum<>, T>::value);
static_assert(has_known_identity<sycl::maximum<T>, T>::value);
static_assert(has_known_identity<sycl::minimum<>, T>::value);
static_assert(has_known_identity<sycl::minimum<T>, T>::value);
}

template <typename T> void checkCommonKnownIdentity() {
Expand Down Expand Up @@ -100,7 +100,7 @@ int main() {

// Few negative tests just to check that it does not always return true.
static_assert(!has_known_identity<std::minus<>, int>::value);
static_assert(!has_known_identity<ext::oneapi::bit_or<>, float>::value);
static_assert(!has_known_identity<sycl::bit_or<>, float>::value);

return 0;
}