Skip to content

Commit 6edca52

Browse files
shiltian
authored and bader committed
[SYCL] Add support for std::numeric_limits<cl::sycl::half> (#536)
Implemented the explicit specialization of the class template `std::numeric_limits<HALF_TYPE>`, where HALF_TYPE is `_Float16` on the device side and `cl::sycl::detail::half_impl::half` on the host side. Also defined some macros corresponding to their FP32 counterparts, such as `SYCL_HLF_MIN`, `SYCL_HLF_MAX`, etc. Signed-off-by: Shilei Tian <[email protected]>
1 parent 916c32d commit 6edca52

File tree

2 files changed

+225
-23
lines changed

2 files changed

+225
-23
lines changed

sycl/include/CL/sycl/half_type.hpp

Lines changed: 128 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88

99
#pragma once
1010

11+
#include <cmath>
1112
#include <cstdint>
1213
#include <functional>
1314
#include <iostream>
15+
#include <limits>
1416

1517
namespace cl {
1618
namespace sycl {
@@ -73,8 +75,8 @@ class half {
7375
// on arithmetic types. We can't specify half type as arithmetic/floating
7476
// point(via std::is_floating_point) since only float, double and long double
7577
// types are "floating point" according to the standard. In order to use half
76-
// type with these math functions we cast half to float using template function
77-
// helper.
78+
// type with these math functions we cast half to float using template
79+
// function helper.
7880
template <typename T> inline T cast_if_host_half(T val) { return val; }
7981

8082
inline float cast_if_host_half(half_impl::half val) {
@@ -86,22 +88,136 @@ inline float cast_if_host_half(half_impl::half val) {
8688
} // namespace sycl
8789
} // namespace cl
8890

91+
#ifdef __SYCL_DEVICE_ONLY__
92+
using half = _Float16;
93+
#else
94+
using half = cl::sycl::detail::half_impl::half;
95+
#endif
96+
97+
// Explicit specializations of some class templates in namespace `std`
8998
namespace std {
9099

91-
template <> struct hash<cl::sycl::detail::half_impl::half> {
92-
size_t operator()(cl::sycl::detail::half_impl::half const &key) const
93-
noexcept {
94-
return hash<uint16_t>()(key.Buf);
100+
#ifdef __SYCL_DEVICE_ONLY__
101+
// `constexpr` could work because the implicit conversion from `float` to
102+
// `_Float16` can be `constexpr`.
103+
#define CONSTEXPR_QUALIFIER constexpr
104+
#else
105+
// The qualifier is `const` rather than the originally intended `constexpr`
106+
// because the constructor of the host-side `half` class is not `constexpr`.
107+
#define CONSTEXPR_QUALIFIER const
108+
#endif
109+
110+
// Explicit specialization of `std::hash<cl::sycl::half>`
111+
template <> struct hash<half> {
112+
size_t operator()(half const &Key) const noexcept {
113+
return hash<uint16_t>{}(reinterpret_cast<const uint16_t &>(Key));
95114
}
96115
};
97116

98-
} // namespace std
117+
// Explicit specialization of `std::numeric_limits<cl::sycl::half>`
99118

100-
#ifdef __SYCL_DEVICE_ONLY__
101-
using half = _Float16;
102-
#else
103-
using half = cl::sycl::detail::half_impl::half;
104-
#endif
119+
// All of the following values are either calculated from the description of
120+
// each function/value at https://en.cppreference.com/w/cpp/types/numeric_limits
121+
// or taken from cl_platform.h.
122+
#define SYCL_HLF_MIN 6.103515625e-05F
123+
124+
#define SYCL_HLF_MAX 65504.0F
125+
126+
#define SYCL_HLF_MAX_10_EXP 4
127+
128+
#define SYCL_HLF_MAX_EXP 16
129+
130+
#define SYCL_HLF_MIN_10_EXP -4
131+
132+
#define SYCL_HLF_MIN_EXP -13
133+
134+
#define SYCL_HLF_MANT_DIG 11
135+
136+
#define SYCL_HLF_DIG 3
137+
138+
#define SYCL_HLF_DECIMAL_DIG 5
139+
140+
#define SYCL_HLF_EPSILON 9.765625e-04F
141+
142+
#define SYCL_HLF_RADIX 2
143+
144+
template <> struct numeric_limits<half> {
145+
static constexpr const bool is_specialized = true;
146+
147+
static constexpr const bool is_signed = true;
148+
149+
static constexpr const bool is_integer = false;
150+
151+
static constexpr const bool is_exact = false;
152+
153+
static constexpr const bool has_infinity = true;
154+
155+
static constexpr const bool has_quiet_NaN = true;
156+
157+
static constexpr const bool has_signaling_NaN = true;
158+
159+
static constexpr const float_denorm_style has_denorm = denorm_present;
160+
161+
static constexpr const bool has_denorm_loss = false;
162+
163+
static constexpr const bool tinyness_before = false;
164+
165+
static constexpr const bool traps = false;
166+
167+
static constexpr const int max_exponent10 = SYCL_HLF_MAX_10_EXP;
168+
169+
static constexpr const int max_exponent = SYCL_HLF_MAX_EXP;
170+
171+
static constexpr const int min_exponent10 = SYCL_HLF_MIN_10_EXP;
172+
173+
static constexpr const int min_exponent = SYCL_HLF_MIN_EXP;
174+
175+
static constexpr const int radix = SYCL_HLF_RADIX;
176+
177+
static constexpr const int max_digits10 = SYCL_HLF_DECIMAL_DIG;
178+
179+
static constexpr const int digits = SYCL_HLF_MANT_DIG;
180+
181+
static constexpr const bool is_bounded = true;
182+
183+
static constexpr const int digits10 = SYCL_HLF_DIG;
184+
185+
static constexpr const bool is_modulo = false;
186+
187+
static constexpr const bool is_iec559 = true;
188+
189+
static constexpr const float_round_style round_style = round_to_nearest;
190+
191+
static CONSTEXPR_QUALIFIER half min() noexcept { return SYCL_HLF_MIN; }
192+
193+
static CONSTEXPR_QUALIFIER half max() noexcept { return SYCL_HLF_MAX; }
194+
195+
static CONSTEXPR_QUALIFIER half lowest() noexcept { return -SYCL_HLF_MAX; }
196+
197+
static CONSTEXPR_QUALIFIER half epsilon() noexcept {
198+
return SYCL_HLF_EPSILON;
199+
}
200+
201+
static CONSTEXPR_QUALIFIER half round_error() noexcept { return 0.5F; }
202+
203+
static CONSTEXPR_QUALIFIER half infinity() noexcept {
204+
return __builtin_huge_valf();
205+
}
206+
207+
static CONSTEXPR_QUALIFIER half quiet_NaN() noexcept {
208+
return __builtin_nanf("");
209+
}
210+
211+
static CONSTEXPR_QUALIFIER half signaling_NaN() noexcept {
212+
return __builtin_nansf("");
213+
}
214+
215+
static CONSTEXPR_QUALIFIER half denorm_min() noexcept { return 5.96046e-08F; }
216+
};
217+
218+
#undef CONSTEXPR_QUALIFIER
219+
220+
} // namespace std
105221

106222
inline std::ostream &operator<<(std::ostream &O, half const &rhs) {
107223
O << static_cast<float>(rhs);

sycl/test/basic_tests/half_type.cpp

Lines changed: 97 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,22 @@
1515
#include <CL/sycl.hpp>
1616

1717
#include <cmath>
18+
#include <unordered_set>
1819

1920
using namespace cl::sycl;
2021

21-
constexpr float flt_epsilon = 9.77e-4;
22-
2322
constexpr size_t N = 100;
2423

25-
template <typename T> void assert_close(const T &C, const float ref) {
24+
template <typename T> void assert_close(const T &C, const cl::sycl::half ref) {
2625
for (size_t i = 0; i < N; i++) {
27-
float diff = C[i] - ref;
28-
assert(std::fabs(diff) < flt_epsilon);
26+
auto diff = C[i] - ref;
27+
assert(std::fabs(static_cast<float>(diff)) <
28+
std::numeric_limits<cl::sycl::half>::epsilon());
2929
}
3030
}
3131

3232
void verify_add(queue &q, buffer<half, 1> &a, buffer<half, 1> &b, range<1> &r,
33-
const float ref) {
33+
const half ref) {
3434
buffer<half, 1> c{r};
3535

3636
q.submit([&](handler &cgh) {
@@ -45,7 +45,7 @@ void verify_add(queue &q, buffer<half, 1> &a, buffer<half, 1> &b, range<1> &r,
4545
}
4646

4747
void verify_min(queue &q, buffer<half, 1> &a, buffer<half, 1> &b, range<1> &r,
48-
const float ref) {
48+
const half ref) {
4949
buffer<half, 1> c{r};
5050

5151
q.submit([&](handler &cgh) {
@@ -60,7 +60,7 @@ void verify_min(queue &q, buffer<half, 1> &a, buffer<half, 1> &b, range<1> &r,
6060
}
6161

6262
void verify_mul(queue &q, buffer<half, 1> &a, buffer<half, 1> &b, range<1> &r,
63-
const float ref) {
63+
const half ref) {
6464
buffer<half, 1> c{r};
6565

6666
q.submit([&](handler &cgh) {
@@ -103,17 +103,97 @@ void verify_vec(queue &q) {
103103
assert(e.get_access<access::mode::read>()[0] == 0);
104104
}
105105

106+
void verify_numeric_limits(queue &q) {
107+
// Verify on host side
108+
// Static member variables
109+
std::numeric_limits<cl::sycl::half>::is_specialized;
110+
std::numeric_limits<cl::sycl::half>::is_signed;
111+
std::numeric_limits<cl::sycl::half>::is_integer;
112+
std::numeric_limits<cl::sycl::half>::is_exact;
113+
std::numeric_limits<cl::sycl::half>::has_infinity;
114+
std::numeric_limits<cl::sycl::half>::has_quiet_NaN;
115+
std::numeric_limits<cl::sycl::half>::has_signaling_NaN;
116+
std::numeric_limits<cl::sycl::half>::has_denorm;
117+
std::numeric_limits<cl::sycl::half>::has_denorm_loss;
118+
std::numeric_limits<cl::sycl::half>::tinyness_before;
119+
std::numeric_limits<cl::sycl::half>::traps;
120+
std::numeric_limits<cl::sycl::half>::max_exponent10;
121+
std::numeric_limits<cl::sycl::half>::max_exponent;
122+
std::numeric_limits<cl::sycl::half>::min_exponent10;
123+
std::numeric_limits<cl::sycl::half>::min_exponent;
124+
std::numeric_limits<cl::sycl::half>::radix;
125+
std::numeric_limits<cl::sycl::half>::max_digits10;
126+
std::numeric_limits<cl::sycl::half>::digits;
127+
std::numeric_limits<cl::sycl::half>::is_bounded;
128+
std::numeric_limits<cl::sycl::half>::digits10;
129+
std::numeric_limits<cl::sycl::half>::is_modulo;
130+
std::numeric_limits<cl::sycl::half>::is_iec559;
131+
std::numeric_limits<cl::sycl::half>::round_style;
132+
133+
// Static member functions
134+
std::numeric_limits<cl::sycl::half>::min();
135+
std::numeric_limits<cl::sycl::half>::max();
136+
std::numeric_limits<cl::sycl::half>::lowest();
137+
std::numeric_limits<cl::sycl::half>::epsilon();
138+
std::numeric_limits<cl::sycl::half>::round_error();
139+
std::numeric_limits<cl::sycl::half>::infinity();
140+
std::numeric_limits<cl::sycl::half>::quiet_NaN();
141+
std::numeric_limits<cl::sycl::half>::signaling_NaN();
142+
std::numeric_limits<cl::sycl::half>::denorm_min();
143+
144+
// Verify in kernel function for device side check
145+
q.submit([&](cl::sycl::handler &cgh) {
146+
cgh.single_task<class kernel>([]() {
147+
// Static member variables
148+
std::numeric_limits<cl::sycl::half>::is_specialized;
149+
std::numeric_limits<cl::sycl::half>::is_signed;
150+
std::numeric_limits<cl::sycl::half>::is_integer;
151+
std::numeric_limits<cl::sycl::half>::is_exact;
152+
std::numeric_limits<cl::sycl::half>::has_infinity;
153+
std::numeric_limits<cl::sycl::half>::has_quiet_NaN;
154+
std::numeric_limits<cl::sycl::half>::has_signaling_NaN;
155+
std::numeric_limits<cl::sycl::half>::has_denorm;
156+
std::numeric_limits<cl::sycl::half>::has_denorm_loss;
157+
std::numeric_limits<cl::sycl::half>::tinyness_before;
158+
std::numeric_limits<cl::sycl::half>::traps;
159+
std::numeric_limits<cl::sycl::half>::max_exponent10;
160+
std::numeric_limits<cl::sycl::half>::max_exponent;
161+
std::numeric_limits<cl::sycl::half>::min_exponent10;
162+
std::numeric_limits<cl::sycl::half>::min_exponent;
163+
std::numeric_limits<cl::sycl::half>::radix;
164+
std::numeric_limits<cl::sycl::half>::max_digits10;
165+
std::numeric_limits<cl::sycl::half>::digits;
166+
std::numeric_limits<cl::sycl::half>::is_bounded;
167+
std::numeric_limits<cl::sycl::half>::digits10;
168+
std::numeric_limits<cl::sycl::half>::is_modulo;
169+
std::numeric_limits<cl::sycl::half>::is_iec559;
170+
std::numeric_limits<cl::sycl::half>::round_style;
171+
172+
// Static member functions
173+
std::numeric_limits<cl::sycl::half>::min();
174+
std::numeric_limits<cl::sycl::half>::max();
175+
std::numeric_limits<cl::sycl::half>::lowest();
176+
std::numeric_limits<cl::sycl::half>::epsilon();
177+
std::numeric_limits<cl::sycl::half>::round_error();
178+
std::numeric_limits<cl::sycl::half>::infinity();
179+
std::numeric_limits<cl::sycl::half>::quiet_NaN();
180+
std::numeric_limits<cl::sycl::half>::signaling_NaN();
181+
std::numeric_limits<cl::sycl::half>::denorm_min();
182+
});
183+
});
184+
}
185+
106186
inline bool bitwise_comparison_fp16(const half val, const uint16_t exp) {
107-
return reinterpret_cast<const uint16_t&>(val) == exp;
187+
return reinterpret_cast<const uint16_t &>(val) == exp;
108188
}
109189

110190
inline bool bitwise_comparison_fp32(const half val, const uint32_t exp) {
111191
const float fp32 = static_cast<float>(val);
112-
return reinterpret_cast<const uint32_t&>(fp32) == exp;
192+
return reinterpret_cast<const uint32_t &>(fp32) == exp;
113193
}
114194

115195
int main() {
116-
// We assert that the length is 1 because we use macro to select the device
196+
// We assert that the length is 1 because we use an environment variable to
// select the device
117197
assert(device::get_devices().size() == 1);
118198

119199
auto dev = device::get_devices()[0];
@@ -137,6 +217,7 @@ int main() {
137217
verify_mul(q, a, b, r, 10.0);
138218
verify_div(q, a, b, r, 2.5);
139219
verify_vec(q);
220+
verify_numeric_limits(q);
140221

141222
if (!dev.is_host()) {
142223
return 0;
@@ -197,5 +278,10 @@ int main() {
197278
assert(bitwise_comparison_fp32(reinterpret_cast<const half &>(subnormal),
198279
882900992));
199280

281+
// std::hash<cl::sycl::half>
282+
std::unordered_set<half> sets;
283+
sets.insert(1.2);
284+
assert(sets.find(1.2) != sets.end());
285+
200286
return 0;
201287
}

0 commit comments

Comments
 (0)