Skip to content

Commit 3aabd26

Browse files
authored
[SYCL] Stop reinterpret from changing buffer_allocator type to const (#6769)
This PR prevents the `buffer_allocator` from being rebound to a const type. Currently when a `buffer` is reinterpreted from type `T` to type `const T` the `buffer_allocator` type is also changed to `const T`. This does not agree with the SYCL spec which states "A buffer of data type const T uses buffer_allocator<T> by default." This PR also adds a compile time test which validates the returned type.
1 parent a3a88bc commit 3aabd26

File tree

2 files changed

+76
-6
lines changed

2 files changed

+76
-6
lines changed

sycl/include/sycl/buffer.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ class buffer : public detail::buffer_plain {
627627
template <typename ReinterpretT, int ReinterpretDim>
628628
buffer<ReinterpretT, ReinterpretDim,
629629
typename std::allocator_traits<AllocatorT>::template rebind_alloc<
630-
ReinterpretT>>
630+
std::remove_const_t<ReinterpretT>>>
631631
reinterpret(range<ReinterpretDim> reinterpretRange) const {
632632
if (sizeof(ReinterpretT) * reinterpretRange.size() != byte_size())
633633
throw sycl::invalid_object_error(
@@ -637,8 +637,8 @@ class buffer : public detail::buffer_plain {
637637
PI_ERROR_INVALID_VALUE);
638638

639639
return buffer<ReinterpretT, ReinterpretDim,
640-
typename std::allocator_traits<
641-
AllocatorT>::template rebind_alloc<ReinterpretT>>(
640+
typename std::allocator_traits<AllocatorT>::
641+
template rebind_alloc<std::remove_const_t<ReinterpretT>>>(
642642
impl, reinterpretRange, OffsetInBytes, IsSubBuffer);
643643
}
644644

@@ -647,11 +647,11 @@ class buffer : public detail::buffer_plain {
647647
(sizeof(ReinterpretT) == sizeof(T)) && (dimensions == ReinterpretDim),
648648
buffer<ReinterpretT, ReinterpretDim,
649649
typename std::allocator_traits<AllocatorT>::template rebind_alloc<
650-
ReinterpretT>>>::type
650+
std::remove_const_t<ReinterpretT>>>>::type
651651
reinterpret() const {
652652
return buffer<ReinterpretT, ReinterpretDim,
653-
typename std::allocator_traits<
654-
AllocatorT>::template rebind_alloc<ReinterpretT>>(
653+
typename std::allocator_traits<AllocatorT>::
654+
template rebind_alloc<std::remove_const_t<ReinterpretT>>>(
655655
impl, get_range(), OffsetInBytes, IsSubBuffer);
656656
}
657657

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// RUN: %clangxx -fsycl -fsyntax-only %s
2+
3+
#include <sycl/sycl.hpp>
4+
5+
template <int Dimensions> sycl::range<Dimensions> create_range() {
6+
return sycl::range<Dimensions>(1);
7+
}
8+
9+
template <> sycl::range<2> create_range() { return sycl::range<2>(1, 1); }
10+
11+
template <> sycl::range<3> create_range() { return sycl::range<3>(1, 1, 1); }
12+
13+
// Compile only test to check that buffer_allocator type does not get
14+
// reinterpreted with const keyworkd
15+
template <typename T, int Dimensions,
16+
typename Allocator = sycl::buffer_allocator<T>>
17+
void test_buffer_const_reinterpret() {
18+
sycl::buffer<T, Dimensions, Allocator> buff(create_range<Dimensions>());
19+
sycl::buffer<T const, Dimensions, Allocator> const_buff(
20+
create_range<Dimensions>());
21+
22+
auto reinterpret_buff = buff.template reinterpret<T const, Dimensions>(
23+
create_range<Dimensions>());
24+
25+
static_assert(
26+
std::is_same_v<decltype(const_buff), decltype(reinterpret_buff)>);
27+
}
28+
29+
struct my_struct {
30+
int my_int = 0;
31+
float my_float = 0;
32+
double my_double = 0;
33+
};
34+
35+
int main() {
36+
test_buffer_const_reinterpret<short, 1>();
37+
test_buffer_const_reinterpret<int, 1>();
38+
test_buffer_const_reinterpret<long, 1>();
39+
test_buffer_const_reinterpret<unsigned short, 1>();
40+
test_buffer_const_reinterpret<unsigned int, 1>();
41+
test_buffer_const_reinterpret<unsigned long, 1>();
42+
test_buffer_const_reinterpret<sycl::half, 1>();
43+
test_buffer_const_reinterpret<float, 1>();
44+
test_buffer_const_reinterpret<double, 1>();
45+
test_buffer_const_reinterpret<my_struct, 1>();
46+
47+
test_buffer_const_reinterpret<short, 2>();
48+
test_buffer_const_reinterpret<int, 2>();
49+
test_buffer_const_reinterpret<long, 2>();
50+
test_buffer_const_reinterpret<unsigned short, 2>();
51+
test_buffer_const_reinterpret<unsigned int, 2>();
52+
test_buffer_const_reinterpret<unsigned long, 2>();
53+
test_buffer_const_reinterpret<sycl::half, 2>();
54+
test_buffer_const_reinterpret<float, 2>();
55+
test_buffer_const_reinterpret<double, 2>();
56+
test_buffer_const_reinterpret<my_struct, 2>();
57+
58+
test_buffer_const_reinterpret<short, 3>();
59+
test_buffer_const_reinterpret<int, 3>();
60+
test_buffer_const_reinterpret<long, 3>();
61+
test_buffer_const_reinterpret<unsigned short, 3>();
62+
test_buffer_const_reinterpret<unsigned int, 3>();
63+
test_buffer_const_reinterpret<unsigned long, 3>();
64+
test_buffer_const_reinterpret<sycl::half, 3>();
65+
test_buffer_const_reinterpret<float, 3>();
66+
test_buffer_const_reinterpret<double, 3>();
67+
test_buffer_const_reinterpret<my_struct, 3>();
68+
69+
return 0;
70+
}

0 commit comments

Comments
 (0)