Skip to content

Commit eb4b529

Browse files
kbobrovs and v-klochkov
authored
[ESIMD] Enable some of intrinsic math ops on wrapper element types. (#5271)
* [ESIMD] Enable some of the intrinsic math ops on wrapper element types. - Enable inv, log, exp, sqrt, sqrt_ieee, rsqrt, sin, cos, pow, and saturation on wrapper element types (sycl::half at the moment) - Fix the host implementation to support wrapper types - TODO: some intrinsics, such as rounding and inverse trigonometric functions, are still not supported and will cause a compile-time error Signed-off-by: Konstantin S Bobrovsky <[email protected]> Co-authored-by: Vyacheslav Klochkov <[email protected]>
1 parent 8d0ec2c commit eb4b529

File tree

9 files changed

+732
-715
lines changed

9 files changed

+732
-715
lines changed

sycl/include/CL/sycl/builtins_esimd.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ cos(__ESIMD_NS::simd<float, SZ> x) __NOEXC {
2828
#ifdef __SYCL_DEVICE_ONLY__
2929
return __ESIMD_NS::detail::ocl_cos<SZ>(x.data());
3030
#else
31-
return __esimd_cos<SZ>(x.data());
31+
return __esimd_cos<float, SZ>(x.data());
3232
#endif // __SYCL_DEVICE_ONLY__
3333
}
3434

@@ -39,7 +39,7 @@ sin(__ESIMD_NS::simd<float, SZ> x) __NOEXC {
3939
#ifdef __SYCL_DEVICE_ONLY__
4040
return __ESIMD_NS::detail::ocl_sin<SZ>(x.data());
4141
#else
42-
return __esimd_sin<SZ>(x.data());
42+
return __esimd_sin<float, SZ>(x.data());
4343
#endif // __SYCL_DEVICE_ONLY__
4444
}
4545

@@ -50,7 +50,7 @@ exp(__ESIMD_NS::simd<float, SZ> x) __NOEXC {
5050
#ifdef __SYCL_DEVICE_ONLY__
5151
return __ESIMD_NS::detail::ocl_exp<SZ>(x.data());
5252
#else
53-
return __esimd_exp<SZ>(x.data());
53+
return __esimd_exp<float, SZ>(x.data());
5454
#endif // __SYCL_DEVICE_ONLY__
5555
}
5656

@@ -61,7 +61,7 @@ log(__ESIMD_NS::simd<float, SZ> x) __NOEXC {
6161
#ifdef __SYCL_DEVICE_ONLY__
6262
return __ESIMD_NS::detail::ocl_log<SZ>(x.data());
6363
#else
64-
return __esimd_log<SZ>(x.data());
64+
return __esimd_log<float, SZ>(x.data());
6565
#endif // __SYCL_DEVICE_ONLY__
6666
}
6767

sycl/include/sycl/ext/intel/experimental/esimd/detail/elem_type_traits.hpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ template <class T, class SFINAE> struct element_type_traits {
133133
// Whether a value or clang vector value of the raw element type can be used
134134
// directly as operand to std C++ operations.
135135
static inline constexpr bool use_native_cpp_ops = true;
136+
// W/A for an MSVC compiler problem: the compiler thinks
137+
// std::is_floating_point_v<_Float16> is false; so require new element type
138+
// implementations to state the "is floating point" trait explicitly
139+
static inline constexpr bool is_floating_point = false;
136140
};
137141

138142
// Element type traits specialization for C++ standard element type.
@@ -141,8 +145,19 @@ struct element_type_traits<T, std::enable_if_t<is_vectorizable_v<T>>> {
141145
using RawT = T;
142146
using EnclosingCppT = T;
143147
static inline constexpr bool use_native_cpp_ops = true;
148+
static inline constexpr bool is_floating_point = std::is_floating_point_v<T>;
144149
};
145150

151+
#ifdef __SYCL_DEVICE_ONLY__
152+
template <> struct element_type_traits<_Float16, void> {
153+
using RawT = _Float16;
154+
using EnclosingCppT = _Float16;
155+
__SYCL_DEPRECATED("use sycl::half as element type")
156+
static inline constexpr bool use_native_cpp_ops = true;
157+
static inline constexpr bool is_floating_point = true;
158+
};
159+
#endif
160+
146161
// --- Type conversions
147162

148163
// Low-level conversion functions to and from a wrapper element type.
@@ -563,7 +578,7 @@ class WrapperElementTypeProxy {
563578
// the wrapper floating-point types such as sycl::half.
564579
template <typename T>
565580
static inline constexpr bool is_generic_floating_point_v =
566-
std::is_floating_point_v<typename element_type_traits<T>::EnclosingCppT>;
581+
element_type_traits<T>::is_floating_point;
567582

568583
// @{
569584
// Get computation type of a binary operator given its operand types:
@@ -664,6 +679,8 @@ struct element_type_traits<T, std::enable_if_t<std::is_same_v<T, sycl::half>>> {
664679
// operations on half type.
665680
static inline constexpr bool use_native_cpp_ops = false;
666681
#endif // __SYCL_DEVICE_ONLY__
682+
683+
static inline constexpr bool is_floating_point = true;
667684
};
668685

669686
using half_raw = __raw_t<sycl::half>;

sycl/include/sycl/ext/intel/experimental/esimd/detail/host_util.hpp

Lines changed: 34 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212

1313
#ifndef __SYCL_DEVICE_ONLY__
1414

15+
#include <assert.h>
1516
#include <limits>
1617

18+
#include <sycl/ext/intel/experimental/esimd/detail/elem_type_traits.hpp>
19+
1720
#define SIMDCF_ELEMENT_SKIP(i)
1821

1922
__SYCL_INLINE_NAMESPACE(cl) {
@@ -46,7 +49,12 @@ static long long abs(long long a) {
4649
}
4750
}
4851

49-
template <typename RT> struct satur {
52+
template <typename RT, class SFINAE = void> struct satur;
53+
54+
template <typename RT>
55+
struct satur<RT, std::enable_if_t<std::is_integral_v<RT>>> {
56+
static_assert(!__SEIEED::is_wrapper_elem_type_v<RT>);
57+
5058
template <typename T> static RT saturate(const T val, const int flags) {
5159
if ((flags & sat_is_on) == 0) {
5260
return (RT)val;
@@ -72,35 +80,29 @@ template <typename RT> struct satur {
7280
}
7381
};
7482

75-
template <> struct satur<float> {
76-
template <typename T> static float saturate(const T val, const int flags) {
77-
if ((flags & sat_is_on) == 0) {
78-
return (float)val;
79-
}
80-
81-
if (val < 0.) {
82-
return 0;
83-
} else if (val > 1.) {
84-
return 1.;
85-
} else {
86-
return (float)val;
87-
}
88-
}
89-
};
90-
91-
template <> struct satur<double> {
92-
template <typename T> static double saturate(const T val, const int flags) {
93-
if ((flags & sat_is_on) == 0) {
94-
return (double)val;
83+
// Host implementation of saturation for FP types, including non-standard
84+
// wrapper types such as sycl::half. Template parameters are defined in terms
85+
// of user-level types (sycl::half), function parameter and return types -
86+
// in terms of raw bit representation type (_Float16 for half on device).
87+
template <class Tdst>
88+
struct satur<Tdst,
89+
std::enable_if_t<__SEIEED::is_generic_floating_point_v<Tdst>>> {
90+
template <typename Tsrc>
91+
static __SEIEED::__raw_t<Tdst> saturate(const __SEIEED::__raw_t<Tsrc> raw_src,
92+
const int flags) {
93+
Tsrc src = __SEIEED::bitcast_to_wrapper_type<Tsrc>(raw_src);
94+
95+
// perform comparison on user type!
96+
if ((flags & sat_is_on) == 0 || (src >= 0 && src <= 1)) {
97+
// convert_scalar accepts/returns user types - need to bitcast
98+
Tdst dst = __SEIEED::convert_scalar<Tdst, Tsrc>(src);
99+
return __SEIEED::bitcast_to_raw_type<Tdst>(dst);
95100
}
96-
97-
if (val < 0.) {
98-
return 0;
99-
} else if (val > 1.) {
100-
return 1.;
101-
} else {
102-
return (double)val;
101+
if (src < 0) {
102+
return __SEIEED::bitcast_to_raw_type<Tdst>(Tdst{0});
103103
}
104+
assert(src > 1);
105+
return __SEIEED::bitcast_to_raw_type<Tdst>(Tdst{1});
104106
}
105107
};
106108

@@ -116,6 +118,10 @@ template <> struct SetSatur<double, true> {
116118
static unsigned int set() { return sat_is_on; }
117119
};
118120

121+
// TODO replace restype_ex with detail::computation_type_t and represent half
122+
// as sycl::half rather than 'using half = sycl::detail::half_impl::half;'
123+
// above
124+
119125
// used for intermediate type in dp4a emulation
120126
template <typename T1, typename T2> struct restype_ex {
121127
private:
@@ -430,36 +436,6 @@ template <> struct fptype<float> { static const bool value = true; };
430436
template <typename T> struct dftype { static const bool value = false; };
431437
template <> struct dftype<double> { static const bool value = true; };
432438

433-
template <typename T> struct esimdtype;
434-
template <> struct esimdtype<char> { static const bool value = true; };
435-
436-
template <> struct esimdtype<signed char> { static const bool value = true; };
437-
438-
template <> struct esimdtype<unsigned char> { static const bool value = true; };
439-
440-
template <> struct esimdtype<short> { static const bool value = true; };
441-
442-
template <> struct esimdtype<unsigned short> {
443-
static const bool value = true;
444-
};
445-
template <> struct esimdtype<int> { static const bool value = true; };
446-
447-
template <> struct esimdtype<unsigned int> { static const bool value = true; };
448-
449-
template <> struct esimdtype<unsigned long> { static const bool value = true; };
450-
451-
template <> struct esimdtype<half> { static const bool value = true; };
452-
453-
template <> struct esimdtype<float> { static const bool value = true; };
454-
455-
template <> struct esimdtype<double> { static const bool value = true; };
456-
457-
template <> struct esimdtype<long long> { static const bool value = true; };
458-
459-
template <> struct esimdtype<unsigned long long> {
460-
static const bool value = true;
461-
};
462-
463439
template <typename T> struct bytetype;
464440
template <> struct bytetype<char> { static const bool value = true; };
465441
template <> struct bytetype<unsigned char> { static const bool value = true; };

0 commit comments

Comments
 (0)