Commit 6ac0a3f

joeatodd and Alcpz authored
[SYCL][COMPAT] Add bfe_safe and bfi_safe APIs (#14006)
This PR adds bit-field extract (`bfe_safe`) and bit-field insert (`bfi_safe`) to the `math.hpp` header. These are 'bounds checked' variants of the `detail::bfe` and `detail::bfi` APIs respectively, though in addition to bounds checking the `_safe` variants also provide:

- asm for NVPTX
- Proper treatment of signed types (`bfe_safe`)

As such, it's not clear whether the 'unsafe' variants ought to be exposed at all and so I've put them in `detail::` for now. What are the expected semantics in relation to the `_safe` variants? They would likely need separate tests, and it's not clear that DPCT use these.

---------

Signed-off-by: Joe Todd <[email protected]>
Co-authored-by: Alberto Cabrera Pérez <[email protected]>
1 parent 9e28bba commit 6ac0a3f

4 files changed: +543 -0 lines changed

sycl/doc/syclcompat/README.md

Lines changed: 42 additions & 0 deletions
@@ -2118,6 +2118,48 @@ template <typename RetT, typename AT, typename BT>
 inline constexpr RetT extend_vavrg2_sat(AT a, BT b, RetT c);
 ```

+The math header file provides APIs for bit-field insertion (`bfi_safe`) and
+bit-field extraction (`bfe_safe`). These are bounds-checked variants of
+underlying `detail` APIs (`detail::bfi`, `detail::bfe`) which, in future
+releases, will be exposed to the user.
+
+```c++
+
+/// Bitfield-insert with boundary checking.
+///
+/// Align and insert a bit field from \param x into \param y . Source \param
+/// bit_start gives the starting bit position for the insertion, and source
+/// \param num_bits gives the bit field length in bits.
+///
+/// \tparam T The type of \param x and \param y , must be an unsigned integer.
+/// \param x The source of the bitfield.
+/// \param y The source where bitfield is inserted.
+/// \param bit_start The position to start insertion.
+/// \param num_bits The number of bits to insertion.
+template <typename T>
+inline T bfi_safe(const T x, const T y, const uint32_t bit_start,
+                  const uint32_t num_bits);
+
+/// Bitfield-extract with boundary checking.
+///
+/// Extract bit field from \param source and return the zero or sign-extended
+/// result. Source \param bit_start gives the bit field starting bit position,
+/// and source \param num_bits gives the bit field length in bits.
+///
+/// The result is padded with the sign bit of the extracted field. If `num_bits`
+/// is zero, the result is zero. If the start position is beyond the msb of the
+/// input, the result is filled with the replicated sign bit of the extracted
+/// field.
+///
+/// \tparam T The type of \param source value, must be an integer.
+/// \param source The source value to extracting.
+/// \param bit_start The position to start extracting.
+/// \param num_bits The number of bits to extracting.
+template <typename T>
+inline T bfe_safe(const T source, const uint32_t bit_start,
+                  const uint32_t num_bits);
+```
+
 ## Sample Code

 Below is a simple linear algebra sample, which computes `y = mx + b` implemented
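
Reviewer note: the README text added above describes the bounds-checked insert and extract only in words. The following is a minimal host-only sketch of those semantics for a 32-bit unsigned value. The helper names (`bfe_u32_sketch`, `bfi_u32_sketch`, `check_bitfield_sketch`) are hypothetical, and the clamping rule is an assumption based on the `pos`/`len` computation in this commit; it is not the syclcompat implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical host-only helpers; NOT the syclcompat implementation.

// Extract `num_bits` bits of `source` starting at `bit_start`, zero-extended.
// Requests that run past bit 31 are clamped instead of shifting out of range.
inline uint32_t bfe_u32_sketch(uint32_t source, uint32_t bit_start,
                               uint32_t num_bits) {
  constexpr uint32_t width = 32;
  if (bit_start >= width || num_bits == 0)
    return 0;
  const uint32_t len = std::min(num_bits, width - bit_start);
  const uint32_t mask =
      len == width ? ~uint32_t{0} : ((uint32_t{1} << len) - 1);
  return (source >> bit_start) & mask;
}

// Insert the low `num_bits` bits of `x` into `y` at `bit_start`.
// A zero-length or fully out-of-range field leaves `y` unchanged.
inline uint32_t bfi_u32_sketch(uint32_t x, uint32_t y, uint32_t bit_start,
                               uint32_t num_bits) {
  constexpr uint32_t width = 32;
  if (bit_start >= width || num_bits == 0)
    return y;
  const uint32_t len = std::min(num_bits, width - bit_start);
  const uint32_t low_mask =
      len == width ? ~uint32_t{0} : ((uint32_t{1} << len) - 1);
  const uint32_t field_mask = low_mask << bit_start;
  return (y & ~field_mask) | ((x << bit_start) & field_mask);
}

// A few worked cases, including clamped and out-of-range requests.
inline void check_bitfield_sketch() {
  assert(bfe_u32_sketch(0xABCD1234u, 8, 8) == 0x12u);
  assert(bfe_u32_sketch(0xABCD1234u, 28, 8) == 0xAu); // clamped to 4 bits
  assert(bfe_u32_sketch(0xABCD1234u, 40, 4) == 0u);   // start past the msb
  assert(bfi_u32_sketch(0xFFu, 0u, 4, 4) == 0xF0u);
  assert(bfi_u32_sketch(0xFFu, 0x12345678u, 32, 8) == 0x12345678u);
}
```

Calling `check_bitfield_sketch()` from a host program exercises the in-range, clamped, and out-of-range cases. Early returns are used here for readability; the commit's `detail::bfi` appears to prefer a branch-free mask instead.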

sycl/include/syclcompat/math.hpp

Lines changed: 152 additions & 0 deletions
@@ -186,8 +186,160 @@ inline bool isnan(const sycl::ext::oneapi::bfloat16 a) {
 }
 #endif

+// FIXME(syclcompat-lib-reviewers): move bfe outside detail once perf is
+// improved & semantics understood
+/// Bitfield-extract.
+///
+/// \tparam T The type of \param source value, must be an integer.
+/// \param source The source value to extracting.
+/// \param bit_start The position to start extracting.
+/// \param num_bits The number of bits to extracting.
+template <typename T>
+inline T bfe(const T source, const uint32_t bit_start,
+             const uint32_t num_bits) {
+  static_assert(std::is_unsigned_v<T>);
+  // FIXME(syclcompat-lib-reviewers): This ternary was added to catch a case
+  // which may be undefined anyway. Consider that we are losing perf here.
+  const T mask =
+      num_bits >= CHAR_BIT * sizeof(T) ? T{-1} : ((T{1} << num_bits) - 1);
+  return (source >> bit_start) & mask;
+}
+
 } // namespace detail

+/// Bitfield-extract with boundary checking.
+///
+/// Extract bit field from \param source and return the zero or sign-extended
+/// result. Source \param bit_start gives the bit field starting bit position,
+/// and source \param num_bits gives the bit field length in bits.
+///
+/// The result is padded with the sign bit of the extracted field. If `num_bits`
+/// is zero, the result is zero. If the start position is beyond the msb of the
+/// input, the result is filled with the replicated sign bit of the extracted
+/// field.
+///
+/// \tparam T The type of \param source value, must be an integer.
+/// \param source The source value to extracting.
+/// \param bit_start The position to start extracting.
+/// \param num_bits The number of bits to extracting.
+template <typename T>
+inline T bfe_safe(const T source, const uint32_t bit_start,
+                  const uint32_t num_bits) {
+  static_assert(std::is_integral_v<T>);
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  if constexpr (std::is_same_v<T, int8_t> || std::is_same_v<T, int16_t> ||
+                std::is_same_v<T, int32_t>) {
+    int32_t res{};
+    asm volatile("bfe.s32 %0, %1, %2, %3;"
+                 : "=r"(res)
+                 : "r"((int32_t)source), "r"(bit_start), "r"(num_bits));
+    return res;
+  } else if constexpr (std::is_same_v<T, uint8_t> ||
+                       std::is_same_v<T, uint16_t> ||
+                       std::is_same_v<T, uint32_t>) {
+    uint32_t res{};
+    asm volatile("bfe.u32 %0, %1, %2, %3;"
+                 : "=r"(res)
+                 : "r"((uint32_t)source), "r"(bit_start), "r"(num_bits));
+    return res;
+  } else if constexpr (std::is_same_v<T, int64_t>) {
+    T res{};
+    asm volatile("bfe.s64 %0, %1, %2, %3;"
+                 : "=l"(res)
+                 : "l"(source), "r"(bit_start), "r"(num_bits));
+    return res;
+  } else if constexpr (std::is_same_v<T, uint64_t>) {
+    T res{};
+    asm volatile("bfe.u64 %0, %1, %2, %3;"
+                 : "=l"(res)
+                 : "l"(source), "r"(bit_start), "r"(num_bits));
+    return res;
+  }
+#endif
+  const uint32_t bit_width = CHAR_BIT * sizeof(T);
+  const uint32_t pos = std::min(bit_start, bit_width);
+  const uint32_t len = std::min(pos + num_bits, bit_width) - pos;
+  if constexpr (std::is_signed_v<T>) {
+    // FIXME(syclcompat-lib-reviewers): As above, catching a case whose result
+    // is undefined and likely losing perf.
+    const T mask = len >= bit_width ? T{-1} : static_cast<T>((T{1} << len) - 1);
+
+    // Find the sign-bit, the result is padded with the sign bit of the
+    // extracted field.
+    // Note if requested num_bits==0, we return zero via sign_bit=0
+    const uint32_t sign_bit_pos = std::min(pos + len - 1, bit_width - 1);
+    const T sign_bit = num_bits != 0 && ((source >> sign_bit_pos) & 1);
+    const T sign_bit_padding = (-sign_bit & ~mask);
+    return ((source >> pos) & mask) | sign_bit_padding;
+  } else {
+    return syclcompat::detail::bfe(source, pos, len);
+  }
+}
+
+namespace detail {
+// FIXME(syclcompat-lib-reviewers): move bfi outside detail once perf is
+// improved & semantics understood
+/// Bitfield-insert.
+///
+/// \tparam T The type of \param x and \param y , must be an unsigned integer.
+/// \param x The source of the bitfield.
+/// \param y The source where bitfield is inserted.
+/// \param bit_start The position to start insertion.
+/// \param num_bits The number of bits to insertion.
+template <typename T>
+inline T bfi(const T x, const T y, const uint32_t bit_start,
+             const uint32_t num_bits) {
+  static_assert(std::is_unsigned_v<T>);
+  constexpr unsigned bit_width = CHAR_BIT * sizeof(T);
+
+  // if bit_start > bit_width || len == 0, should return y.
+  const T ignore_bfi = static_cast<T>(bit_start > bit_width || num_bits == 0);
+  T extract_bitfield_mask = (static_cast<T>(~T{0}) >> (bit_width - num_bits))
+                            << bit_start;
+  T clean_bitfield_mask = ~extract_bitfield_mask;
+  return (y & (-ignore_bfi | clean_bitfield_mask)) |
+         (~-ignore_bfi & ((x << bit_start) & extract_bitfield_mask));
+}
+} // namespace detail
+
+/// Bitfield-insert with boundary checking.
+///
+/// Align and insert a bit field from \param x into \param y . Source \param
+/// bit_start gives the starting bit position for the insertion, and source
+/// \param num_bits gives the bit field length in bits.
+///
+/// \tparam T The type of \param x and \param y , must be an unsigned integer.
+/// \param x The source of the bitfield.
+/// \param y The source where bitfield is inserted.
+/// \param bit_start The position to start insertion.
+/// \param num_bits The number of bits to insertion.
+template <typename T>
+inline T bfi_safe(const T x, const T y, const uint32_t bit_start,
+                  const uint32_t num_bits) {
+  static_assert(std::is_unsigned_v<T>);
+#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
+  if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, uint16_t> ||
+                std::is_same_v<T, uint32_t>) {
+    uint32_t res{};
+    asm volatile("bfi.b32 %0, %1, %2, %3, %4;"
+                 : "=r"(res)
+                 : "r"((uint32_t)x), "r"((uint32_t)y), "r"(bit_start),
+                   "r"(num_bits));
+    return res;
+  } else if constexpr (std::is_same_v<T, uint64_t>) {
+    uint64_t res{};
+    asm volatile("bfi.b64 %0, %1, %2, %3, %4;"
+                 : "=l"(res)
+                 : "l"(x), "l"(y), "r"(bit_start), "r"(num_bits));
+    return res;
+  }
+#endif
+  constexpr unsigned bit_width = CHAR_BIT * sizeof(T);
+  const uint32_t pos = std::min(bit_start, bit_width);
+  const uint32_t len = std::min(pos + num_bits, bit_width) - pos;
+  return syclcompat::detail::bfi(x, y, pos, len);
+}
+
 /// Emulated function for __funnelshift_l
 inline unsigned int funnelshift_l(unsigned int low, unsigned int high,
                                   unsigned int shift) {
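
Reviewer note: the signed fallback path in `bfe_safe` pads the result with the sign bit of the extracted field. A minimal host-only sketch of that rule for `int32_t` follows. The name `bfe_s32_sketch` is hypothetical, and the logic reflects my reading of the doc comment in this commit rather than the merged code. It does the arithmetic on the unsigned bit pattern and returns early when the start is past the msb, so no shift amount ever reaches the type width.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper, NOT part of syclcompat: signed 32-bit bit-field
// extract where the result is padded with the sign bit of the field.
inline int32_t bfe_s32_sketch(int32_t source, uint32_t bit_start,
                              uint32_t num_bits) {
  constexpr uint32_t width = 32;
  if (num_bits == 0)
    return 0; // a zero-length field yields zero

  const uint32_t bits = static_cast<uint32_t>(source);

  // The field's sign bit is its top bit, clamped to the msb of the input.
  // 64-bit arithmetic avoids wrap-around for large bit_start + num_bits.
  const uint64_t top = uint64_t{bit_start} + num_bits - 1;
  const uint32_t sign_pos =
      static_cast<uint32_t>(std::min<uint64_t>(top, width - 1));
  const uint32_t fill = ((bits >> sign_pos) & 1u) ? ~uint32_t{0} : 0u;

  // Start beyond the msb: the result is just the replicated sign bit.
  if (bit_start >= width)
    return static_cast<int32_t>(fill);

  const uint32_t len = std::min(num_bits, width - bit_start);
  if (len == width)
    return source; // whole-word extract, only possible when bit_start == 0

  const uint32_t mask = (uint32_t{1} << len) - 1;
  const uint32_t value = ((bits >> bit_start) & mask) | (fill & ~mask);
  // Conversion back to int32_t is two's complement (guaranteed since C++20).
  return static_cast<int32_t>(value);
}

// Worked cases: negative field, positive field, clamped and out-of-range.
inline void check_signed_sketch() {
  assert(bfe_s32_sketch(0xF0, 4, 4) == -1); // field 0b1111, sign bit set
  assert(bfe_s32_sketch(0x70, 4, 4) == 7);  // field 0b0111, sign bit clear
  assert(bfe_s32_sketch(INT32_MIN, 28, 8) == -8);
  assert(bfe_s32_sketch(-1, 40, 4) == -1);
  assert(bfe_s32_sketch(5, 40, 4) == 0);
}
```

The early returns trade a little branching for never shifting by the full width, which is the case the FIXME comments in the diff flag as possibly undefined.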
Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
+/***************************************************************************
+ *
+ *  Copyright (C) Codeplay Software Ltd.
+ *
+ *  Part of the LLVM Project, under the Apache License v2.0 with LLVM
+ *  Exceptions. See https://llvm.org/LICENSE.txt for license information.
+ *  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License.
+ *
+ *  SYCLcompat API
+ *
+ *  math_bfe.cpp
+ *
+ *  Description:
+ *    math bitfield extract tests
+ **************************************************************************/
+
+// ===----------- math_bfe.cpp ------------------ -*- C++ -* --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//
+// ===---------------------------------------------------------------------===//
+
+// RUN: %clangxx -std=c++17 -fsycl -fsycl-targets=%{sycl_triple} %s -o %t.out
+// RUN: %{run} %t.out
+
+#include <bitset>
+#include <chrono>
+#include <iostream>
+#include <limits.h>
+#include <random>
+#include <stdint.h>
+#include <sycl/detail/core.hpp>
+#include <syclcompat/math.hpp>
+#include <type_traits>
+#include <vector>
+
+template <typename T>
+inline std::enable_if_t<std::is_integral_v<T>, T>
+bfe_slow(const T source, const uint32_t bit_start, const uint32_t num_bits) {
+  const uint32_t msb = CHAR_BIT * sizeof(T) - 1;
+  const uint32_t pos = bit_start;
+  const uint32_t len = num_bits;
+
+  // If the requested bit field length is zero, the result is zero.
+  if (num_bits == 0)
+    return 0ULL;
+
+  T sbit;
+  std::bitset<CHAR_BIT * sizeof(T)> source_bitset(source);
+  if (std::is_unsigned_v<T> || len == 0)
+    sbit = 0;
+  else
+    sbit = source_bitset[std::min(pos + len - 1, msb)];
+
+  // If the start position is beyond the msb of the input, the destination d is
+  // filled with the replicated sign bit of the extracted field.
+  // -1 is 1111...
+  if (bit_start > msb)
+    return -sbit;
+
+  std::bitset<CHAR_BIT * sizeof(T)> result_bitset;
+  for (uint8_t i = 0; i <= msb; ++i)
+    result_bitset[i] =
+        (i < len && pos + i <= msb) ? source_bitset[pos + i] : sbit;
+  return result_bitset.to_ullong();
+}
+
+template <typename T> bool test(const char *Msg, int N) {
+  uint32_t bit_width = CHAR_BIT * sizeof(T);
+  T min_value = std::numeric_limits<T>::min();
+  T max_value = std::numeric_limits<T>::max();
+  std::random_device rd;
+  std::mt19937::result_type seed =
+      rd() ^
+      ((std::mt19937::result_type)
+           std::chrono::duration_cast<std::chrono::seconds>(
+               std::chrono::system_clock::now().time_since_epoch())
+               .count() +
+       (std::mt19937::result_type)
+           std::chrono::duration_cast<std::chrono::microseconds>(
+               std::chrono::high_resolution_clock::now().time_since_epoch())
+               .count());
+
+  std::mt19937 gen(seed);
+  std::uniform_int_distribution<T> rd_source(min_value, max_value);
+
+  // Define a small overshoot so that we adequately test out-of-range cases
+  // without sacrificing depth of testing of valid start+length combinations
+  constexpr uint32_t overshoot = 2;
+  std::uniform_int_distribution<uint32_t> rd_start(0, bit_width + overshoot);
+  std::uniform_int_distribution<uint32_t> rd_length(0, bit_width + overshoot);
+
+  std::vector<T> sources(N, 0);
+  std::vector<T> compat_results(N, 0);
+  std::vector<T> slow_results(N, 0);
+  std::vector<uint32_t> starts(N, 0);
+  std::vector<uint32_t> lengths(N, 0);
+  for (int i = 0; i < N; ++i) {
+    sources[i] = rd_source(gen);
+    starts[i] = rd_start(gen);
+    lengths[i] = rd_length(gen);
+  }
+
+  sycl::buffer<T, 1> source_buffer(sources.data(), N);
+  sycl::buffer<T, 1> compat_results_buffer(compat_results.data(), N);
+  sycl::buffer<T, 1> slow_results_buffer(slow_results.data(), N);
+  sycl::buffer<uint32_t, 1> starts_buffer(starts.data(), N);
+  sycl::buffer<uint32_t, 1> lengths_buffer(lengths.data(), N);
+
+  sycl::queue que;
+  que.submit([&](sycl::handler &handler) {
+    sycl::accessor source_accessor(source_buffer, handler, sycl::read_only);
+    sycl::accessor start_accessor(starts_buffer, handler, sycl::read_only);
+    sycl::accessor length_accessor(lengths_buffer, handler, sycl::read_only);
+    sycl::accessor compat_result_accessor(compat_results_buffer, handler,
+                                          sycl::write_only);
+    handler.parallel_for(N, [=](sycl::id<1> i) {
+      compat_result_accessor[i] = syclcompat::bfe_safe<T>(
+          source_accessor[i], start_accessor[i], length_accessor[i]);
+    });
+  });
+
+  que.submit([&](sycl::handler &handler) {
+    sycl::accessor source_accessor(source_buffer, handler, sycl::read_only);
+    sycl::accessor start_accessor(starts_buffer, handler, sycl::read_only);
+    sycl::accessor length_accessor(lengths_buffer, handler, sycl::read_only);
+    sycl::accessor slow_result_accessor(slow_results_buffer, handler,
+                                        sycl::write_only);
+    handler.parallel_for(N, [=](sycl::id<1> i) {
+      slow_result_accessor[i] = bfe_slow<T>(
+          source_accessor[i], start_accessor[i], length_accessor[i]);
+    });
+  });
+
+  que.wait_and_throw();
+  sycl::host_accessor source_accessor(source_buffer, sycl::read_only);
+  sycl::host_accessor start_accessor(starts_buffer, sycl::read_only);
+  sycl::host_accessor length_accessor(lengths_buffer, sycl::read_only);
+  sycl::host_accessor compat_result_accessor(compat_results_buffer,
+                                             sycl::read_only);
+  sycl::host_accessor slow_result_accessor(slow_results_buffer,
+                                           sycl::read_only);
+
+  int failed = 0;
+  for (int i = 0; i < N; ++i) {
+    if (compat_result_accessor[i] != slow_result_accessor[i]) {
+      failed++;
+      std::cout << "[source = " << source_accessor[i]
+                << ", bit_start = " << start_accessor[i]
+                << ", num_bits = " << length_accessor[i] << "] failed, expect "
+                << slow_result_accessor[i] << " but got "
+                << compat_result_accessor[i] << std::endl;
+    }
+  }
+  std::cout << "===============" << std::endl;
+  std::cout << "Test: " << Msg << std::endl;
+  std::cout << "Total: " << N << std::endl;
+  std::cout << "Success: " << N - failed << std::endl;
+  std::cout << "Failed: " << failed << std::endl;
+  std::cout << "===============" << std::endl;
+  return !failed;
+}
+
+int main() {
+  const int N = 1000;
+  assert(test<int16_t>("int16", N));
+  assert(test<uint16_t>("uint16", N));
+  assert(test<int32_t>("int32", N));
+  assert(test<uint32_t>("uint32", N));
+  assert(test<int64_t>("int64", N));
+  assert(test<uint64_t>("uint64", N));
+  return 0;
+}
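
Reviewer note: the test above checks the fast path against a slow bit-by-bit reference on random inputs with a small overshoot. The same idea can be sketched without SYCL as an exhaustive host-only check for 8-bit unsigned values. All names here (`bfe_u8_reference`, `bfe_u8_fast`, `count_mismatches`, `check_reference_agreement`) are hypothetical, and the "fast" function is a stand-in shift-and-mask, not `syclcompat::bfe_safe`.

```cpp
#include <bitset>
#include <cassert>
#include <cstdint>

// Hypothetical bit-by-bit reference for unsigned 8-bit extraction: copy each
// bit that lies inside both the requested field and the input, zero the rest.
inline uint8_t bfe_u8_reference(uint8_t source, uint32_t bit_start,
                                uint32_t num_bits) {
  const std::bitset<8> in(source);
  std::bitset<8> out;
  for (uint32_t i = 0; i < 8; ++i) {
    const bool in_field = i < num_bits && bit_start + i < 8;
    out[i] = in_field ? in[bit_start + i] : false;
  }
  return static_cast<uint8_t>(out.to_ulong());
}

// Stand-in shift-and-mask version with explicit clamping.
inline uint8_t bfe_u8_fast(uint8_t source, uint32_t bit_start,
                           uint32_t num_bits) {
  if (bit_start >= 8 || num_bits == 0)
    return 0;
  const uint32_t room = 8 - bit_start;
  const uint32_t len = num_bits < room ? num_bits : room;
  const uint32_t mask = (1u << len) - 1; // len <= 8, so the shift is in range
  return static_cast<uint8_t>((source >> bit_start) & mask);
}

// Compare both versions for every source value, with starts and lengths that
// overshoot the 8-bit width by two, mirroring the overshoot idea in the test.
inline int count_mismatches() {
  int failed = 0;
  for (uint32_t s = 0; s < 256; ++s)
    for (uint32_t start = 0; start <= 10; ++start)
      for (uint32_t len = 0; len <= 10; ++len)
        if (bfe_u8_fast(static_cast<uint8_t>(s), start, len) !=
            bfe_u8_reference(static_cast<uint8_t>(s), start, len))
          ++failed;
  return failed;
}

inline void check_reference_agreement() { assert(count_mismatches() == 0); }
```

For 8-bit inputs the search space is small enough to enumerate fully, which removes the dependence on a random seed. For wider types, random sampling as in the test above remains the practical choice.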
