Skip to content

Commit 6314af8

Browse files
authored
[SYCL] Reflect get_pointer updates to legacy specialized multi_ptr. (intel#10417)
* Update the legacy multi_ptr construction from local_accessors to conform with the construction method used by other accessors. * Implement a test for the legacy multi_ptr construction using local_accessor. * Use enable_if in the copy constructor and assignment operator of the legacy/void/const-void multi_ptr specialization to prevent errors caused by multiple overloads of 'multi_ptr'.
1 parent 96d7c4a commit 6314af8

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

sycl/include/sycl/multi_ptr.hpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,7 @@ class __SYCL2020_DEPRECATED(
935935
std::is_const_v<ET> && std::is_same_v<ET, ElementType>>>
936936
multi_ptr(
937937
local_accessor<typename std::remove_const_t<ET>, dimensions> Accessor)
938-
: m_Pointer(detail::cast_AS<pointer_t>(Accessor.get_pointer())) {}
938+
: multi_ptr(Accessor.get_pointer()) {}
939939

940940
// Only if Space == constant_space and element type is const
941941
template <
@@ -1103,6 +1103,10 @@ class __SYCL2020_DEPRECATED(
11031103
multi_ptr(multi_ptr &&) = default;
11041104
multi_ptr(pointer_t pointer) : m_Pointer(pointer) {}
11051105
#ifdef __SYCL_DEVICE_ONLY__
1106+
template <
1107+
typename RelayPointerT = pointer_t,
1108+
typename = std::enable_if_t<std::is_same_v<RelayPointerT, pointer_t> &&
1109+
!std::is_same_v<RelayPointerT, void *>>>
11061110
multi_ptr(void *pointer) : m_Pointer(detail::cast_AS<pointer_t>(pointer)) {
11071111
// TODO An implementation should reject an argument if the deduced
11081112
// address space is not compatible with Space.
@@ -1133,6 +1137,10 @@ class __SYCL2020_DEPRECATED(
11331137
return *this;
11341138
}
11351139
#ifdef __SYCL_DEVICE_ONLY__
1140+
template <
1141+
typename RelayPointerT = pointer_t,
1142+
typename = std::enable_if_t<std::is_same_v<RelayPointerT, pointer_t> &&
1143+
!std::is_same_v<RelayPointerT, void *>>>
11361144
multi_ptr &operator=(void *pointer) {
11371145
// TODO An implementation should reject an argument if the deduced
11381146
// address space is not compatible with Space.
@@ -1180,7 +1188,7 @@ class __SYCL2020_DEPRECATED(
11801188
_Space == Space && (Space == access::address_space::generic_space ||
11811189
Space == access::address_space::local_space)>>
11821190
multi_ptr(local_accessor<ElementType, dimensions> Accessor)
1183-
: m_Pointer(detail::cast_AS<pointer_t>(Accessor.get_pointer())) {}
1191+
: multi_ptr(Accessor.get_pointer()) {}
11841192

11851193
// Only if Space == constant_space
11861194
template <
@@ -1251,6 +1259,10 @@ class __SYCL2020_DEPRECATED(
12511259
multi_ptr(multi_ptr &&) = default;
12521260
multi_ptr(pointer_t pointer) : m_Pointer(pointer) {}
12531261
#ifdef __SYCL_DEVICE_ONLY__
1262+
template <
1263+
typename RelayPointerT = pointer_t,
1264+
typename = std::enable_if_t<std::is_same_v<RelayPointerT, pointer_t> &&
1265+
!std::is_same_v<RelayPointerT, const void *>>>
12541266
multi_ptr(const void *pointer)
12551267
: m_Pointer(detail::cast_AS<pointer_t>(pointer)) {
12561268
// TODO An implementation should reject an argument if the deduced
@@ -1282,6 +1294,10 @@ class __SYCL2020_DEPRECATED(
12821294
return *this;
12831295
}
12841296
#ifdef __SYCL_DEVICE_ONLY__
1297+
template <
1298+
typename RelayPointerT = pointer_t,
1299+
typename = std::enable_if_t<std::is_same_v<RelayPointerT, pointer_t> &&
1300+
!std::is_same_v<RelayPointerT, const void *>>>
12851301
multi_ptr &operator=(const void *pointer) {
12861302
// TODO An implementation should reject an argument if the deduced
12871303
// address space is not compatible with Space.
@@ -1329,7 +1345,7 @@ class __SYCL2020_DEPRECATED(
13291345
_Space == Space && (Space == access::address_space::generic_space ||
13301346
Space == access::address_space::local_space)>>
13311347
multi_ptr(local_accessor<ElementType, dimensions> Accessor)
1332-
: m_Pointer(detail::cast_AS<pointer_t>(Accessor.get_pointer())) {}
1348+
: multi_ptr(Accessor.get_pointer()) {}
13331349

13341350
// Only if Space == constant_space
13351351
template <

sycl/test-e2e/Basic/multi_ptr_legacy.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,30 @@ template <typename T> void testMultPtr() {
7676
access::decorated::legacy>(
7777
localAccessor.get_pointer());
7878

79+
auto local_ptr2 =
80+
multi_ptr<T, access::address_space::local_space,
81+
access::decorated::legacy>(localAccessor);
82+
83+
auto local_ptr3 =
84+
multi_ptr<void, access::address_space::local_space,
85+
access::decorated::legacy>(localAccessor);
86+
87+
auto local_ptr4 =
88+
multi_ptr<const void, access::address_space::local_space,
89+
access::decorated::legacy>(localAccessor);
90+
91+
auto local_ptr5 =
92+
multi_ptr<T, access::address_space::generic_space,
93+
access::decorated::legacy>(localAccessor);
94+
95+
auto local_ptr6 =
96+
multi_ptr<void, access::address_space::generic_space,
97+
access::decorated::legacy>(localAccessor);
98+
99+
auto local_ptr7 =
100+
multi_ptr<const void, access::address_space::generic_space,
101+
access::decorated::legacy>(localAccessor);
102+
79103
// Construct extension pointer from accessors.
80104
auto dev_ptr =
81105
multi_ptr<const T,

0 commit comments

Comments
 (0)