Skip to content

[ESIMD] Enable some of intrinsic math ops on wrapper element types. #5271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions sycl/include/CL/sycl/builtins_esimd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ cos(__ESIMD_NS::simd<float, SZ> x) __NOEXC {
#ifdef __SYCL_DEVICE_ONLY__
return __ESIMD_NS::detail::ocl_cos<SZ>(x.data());
#else
return __esimd_cos<SZ>(x.data());
return __esimd_cos<float, SZ>(x.data());
#endif // __SYCL_DEVICE_ONLY__
}

Expand All @@ -39,7 +39,7 @@ sin(__ESIMD_NS::simd<float, SZ> x) __NOEXC {
#ifdef __SYCL_DEVICE_ONLY__
return __ESIMD_NS::detail::ocl_sin<SZ>(x.data());
#else
return __esimd_sin<SZ>(x.data());
return __esimd_sin<float, SZ>(x.data());
#endif // __SYCL_DEVICE_ONLY__
}

Expand All @@ -50,7 +50,7 @@ exp(__ESIMD_NS::simd<float, SZ> x) __NOEXC {
#ifdef __SYCL_DEVICE_ONLY__
return __ESIMD_NS::detail::ocl_exp<SZ>(x.data());
#else
return __esimd_exp<SZ>(x.data());
return __esimd_exp<float, SZ>(x.data());
#endif // __SYCL_DEVICE_ONLY__
}

Expand All @@ -61,7 +61,7 @@ log(__ESIMD_NS::simd<float, SZ> x) __NOEXC {
#ifdef __SYCL_DEVICE_ONLY__
return __ESIMD_NS::detail::ocl_log<SZ>(x.data());
#else
return __esimd_log<SZ>(x.data());
return __esimd_log<float, SZ>(x.data());
#endif // __SYCL_DEVICE_ONLY__
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ template <class T, class SFINAE> struct element_type_traits {
// Whether a value or clang vector value the raw element type can be used
// directly as operand to std C++ operations.
static inline constexpr bool use_native_cpp_ops = true;
// W/A for MSVC compiler problems which thinks
// std::is_floating_point_v<_Float16> is false; so require new element types
// implementations to state "is floating point" trait explicitly
static inline constexpr bool is_floating_point = false;
};

// Element type traits specialization for C++ standard element type.
Expand All @@ -141,8 +145,19 @@ struct element_type_traits<T, std::enable_if_t<is_vectorizable_v<T>>> {
using RawT = T;
using EnclosingCppT = T;
static inline constexpr bool use_native_cpp_ops = true;
static inline constexpr bool is_floating_point = std::is_floating_point_v<T>;
};

#ifdef __SYCL_DEVICE_ONLY__
template <> struct element_type_traits<_Float16, void> {
using RawT = _Float16;
using EnclosingCppT = _Float16;
__SYCL_DEPRECATED("use sycl::half as element type")
static inline constexpr bool use_native_cpp_ops = true;
static inline constexpr bool is_floating_point = true;
};
#endif

// --- Type conversions

// Low-level conversion functions to and from a wrapper element type.
Expand Down Expand Up @@ -563,7 +578,7 @@ class WrapperElementTypeProxy {
// the wrapper floating-point types such as sycl::half.
template <typename T>
static inline constexpr bool is_generic_floating_point_v =
std::is_floating_point_v<typename element_type_traits<T>::EnclosingCppT>;
element_type_traits<T>::is_floating_point;

// @{
// Get computation type of a binary operator given its operand types:
Expand Down Expand Up @@ -664,6 +679,8 @@ struct element_type_traits<T, std::enable_if_t<std::is_same_v<T, sycl::half>>> {
// operations on half type.
static inline constexpr bool use_native_cpp_ops = false;
#endif // __SYCL_DEVICE_ONLY__

static inline constexpr bool is_floating_point = true;
};

using half_raw = __raw_t<sycl::half>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@

#ifndef __SYCL_DEVICE_ONLY__

#include <assert.h>
#include <limits>

#include <sycl/ext/intel/experimental/esimd/detail/elem_type_traits.hpp>

#define SIMDCF_ELEMENT_SKIP(i)

__SYCL_INLINE_NAMESPACE(cl) {
Expand Down Expand Up @@ -46,7 +49,12 @@ static long long abs(long long a) {
}
}

template <typename RT> struct satur {
template <typename RT, class SFINAE = void> struct satur;

template <typename RT>
struct satur<RT, std::enable_if_t<std::is_integral_v<RT>>> {
static_assert(!__SEIEED::is_wrapper_elem_type_v<RT>);

template <typename T> static RT saturate(const T val, const int flags) {
if ((flags & sat_is_on) == 0) {
return (RT)val;
Expand All @@ -72,35 +80,29 @@ template <typename RT> struct satur {
}
};

template <> struct satur<float> {
template <typename T> static float saturate(const T val, const int flags) {
if ((flags & sat_is_on) == 0) {
return (float)val;
}

if (val < 0.) {
return 0;
} else if (val > 1.) {
return 1.;
} else {
return (float)val;
}
}
};

template <> struct satur<double> {
template <typename T> static double saturate(const T val, const int flags) {
if ((flags & sat_is_on) == 0) {
return (double)val;
// Host implemenation of saturation for FP types, including non-standarad
// wrapper types such as sycl::half. Template parameters are defined in terms
// of user-level types (sycl::half), function parameter and return types -
// in terms of raw bit representation type(_Float16 for half on device).
template <class Tdst>
struct satur<Tdst,
std::enable_if_t<__SEIEED::is_generic_floating_point_v<Tdst>>> {
template <typename Tsrc>
static __SEIEED::__raw_t<Tdst> saturate(const __SEIEED::__raw_t<Tsrc> raw_src,
const int flags) {
Tsrc src = __SEIEED::bitcast_to_wrapper_type<Tsrc>(raw_src);

// perform comparison on user type!
if ((flags & sat_is_on) == 0 || (src >= 0 && src <= 1)) {
// convert_scalar accepts/returns user types - need to bitcast
Tdst dst = __SEIEED::convert_scalar<Tdst, Tsrc>(src);
return __SEIEED::bitcast_to_raw_type<Tdst>(dst);
}

if (val < 0.) {
return 0;
} else if (val > 1.) {
return 1.;
} else {
return (double)val;
if (src < 0) {
return __SEIEED::bitcast_to_raw_type<Tdst>(Tdst{0});
}
assert(src > 1);
return __SEIEED::bitcast_to_raw_type<Tdst>(Tdst{1});
}
};

Expand All @@ -116,6 +118,10 @@ template <> struct SetSatur<double, true> {
static unsigned int set() { return sat_is_on; }
};

// TODO replace restype_ex with detail::computation_type_t and represent half
// as sycl::half rather than 'using half = sycl::detail::half_impl::half;'
// above

// used for intermediate type in dp4a emulation
template <typename T1, typename T2> struct restype_ex {
private:
Expand Down Expand Up @@ -430,36 +436,6 @@ template <> struct fptype<float> { static const bool value = true; };
template <typename T> struct dftype { static const bool value = false; };
template <> struct dftype<double> { static const bool value = true; };

template <typename T> struct esimdtype;
template <> struct esimdtype<char> { static const bool value = true; };

template <> struct esimdtype<signed char> { static const bool value = true; };

template <> struct esimdtype<unsigned char> { static const bool value = true; };

template <> struct esimdtype<short> { static const bool value = true; };

template <> struct esimdtype<unsigned short> {
static const bool value = true;
};
template <> struct esimdtype<int> { static const bool value = true; };

template <> struct esimdtype<unsigned int> { static const bool value = true; };

template <> struct esimdtype<unsigned long> { static const bool value = true; };

template <> struct esimdtype<half> { static const bool value = true; };

template <> struct esimdtype<float> { static const bool value = true; };

template <> struct esimdtype<double> { static const bool value = true; };

template <> struct esimdtype<long long> { static const bool value = true; };

template <> struct esimdtype<unsigned long long> {
static const bool value = true;
};

template <typename T> struct bytetype;
template <> struct bytetype<char> { static const bool value = true; };
template <> struct bytetype<unsigned char> { static const bool value = true; };
Expand Down
Loading