Skip to content

Commit 3fdbfeb

Browse files
[SYCL] Fix return type of the accessor::get_pointer and local_accessor::get_pointer. (#8493)
* accessor::get_pointer and local_accessor::get_pointer return `std::add_pointer_t<value_type>` * Modifies multi_ptr ctors accepting local_accessor. * Improves existing test to check the return type of get_pointer. --------- Co-authored-by: Steffen Larsen <[email protected]>
1 parent 1396da2 commit 3fdbfeb

File tree

3 files changed

+46
-13
lines changed

3 files changed

+46
-13
lines changed

sycl/include/sycl/accessor.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2087,8 +2087,9 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS __SYCL_TYPE(accessor) accessor :
20872087
}
20882088

20892089
template <access::target AccessTarget_ = AccessTarget,
2090-
typename = detail::enable_if_t<AccessTarget_ ==
2091-
access::target::host_buffer>>
2090+
typename = detail::enable_if_t<
2091+
(AccessTarget_ == access::target::host_buffer) ||
2092+
(AccessTarget_ == access::target::host_task)>>
20922093
#if SYCL_LANGUAGE_VERSION >= 202001
20932094
std::add_pointer_t<value_type> get_pointer() const noexcept
20942095
#else
@@ -2663,10 +2664,6 @@ class __SYCL_SPECIAL_CLASS local_accessor_base :
26632664
return AccessorSubscript<Dims - 1>(*this, Index);
26642665
}
26652666

2666-
local_ptr<DataT> get_pointer() const {
2667-
return local_ptr<DataT>(getQualifiedPtr());
2668-
}
2669-
26702667
bool operator==(const local_accessor_base &Rhs) const {
26712668
return impl == Rhs.impl;
26722669
}
@@ -2695,6 +2692,11 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS accessor<
26952692
// Use base classes constructors
26962693
using local_acc::local_acc;
26972694

2695+
public:
2696+
local_ptr<DataT> get_pointer() const {
2697+
return local_ptr<DataT>(local_acc::getQualifiedPtr());
2698+
}
2699+
26982700
#ifdef __SYCL_DEVICE_ONLY__
26992701

27002702
// __init needs to be defined within the class not through inheritance.
@@ -2801,6 +2803,10 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS __SYCL_TYPE(local_accessor) local_accessor
28012803
return const_reverse_iterator(begin());
28022804
}
28032805

2806+
std::add_pointer_t<value_type> get_pointer() const noexcept {
2807+
return std::add_pointer_t<value_type>(local_acc::getQualifiedPtr());
2808+
}
2809+
28042810
template <access::decorated IsDecorated>
28052811
accessor_ptr<IsDecorated> get_multi_ptr() const noexcept {
28062812
return accessor_ptr<IsDecorated>(local_acc::getQualifiedPtr());

sycl/include/sycl/multi_ptr.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class multi_ptr {
149149
(Space == access::address_space::generic_space ||
150150
Space == access::address_space::local_space)>>
151151
multi_ptr(local_accessor<ElementType, Dimensions> Accessor)
152-
: multi_ptr(Accessor.get_pointer().get()) {}
152+
: m_Pointer(detail::cast_AS<decorated_type *>(Accessor.get_pointer())) {}
153153

154154
// The following constructors are necessary to create multi_ptr<const
155155
// ElementType, Space, DecorateAddress> from accessor<ElementType, ...>.
@@ -210,7 +210,7 @@ class multi_ptr {
210210
multi_ptr(local_accessor<typename detail::remove_const_t<RelayElementType>,
211211
Dimensions>
212212
Accessor)
213-
: multi_ptr(Accessor.get_pointer().get()) {}
213+
: m_Pointer(detail::cast_AS<decorated_type *>(Accessor.get_pointer())) {}
214214

215215
// Assignment and access operators
216216
multi_ptr &operator=(const multi_ptr &) = default;
@@ -465,7 +465,7 @@ class multi_ptr<const void, Space, DecorateAddress> {
465465
typename = typename detail::enable_if_t<
466466
RelaySpace == Space && Space == access::address_space::local_space>>
467467
multi_ptr(local_accessor<ElementType, Dimensions> Accessor)
468-
: multi_ptr(Accessor.get_pointer().get()) {}
468+
: m_Pointer(detail::cast_AS<decorated_type *>(Accessor.get_pointer())) {}
469469

470470
// Assignment operators
471471
multi_ptr &operator=(const multi_ptr &) = default;
@@ -591,7 +591,7 @@ class multi_ptr<void, Space, DecorateAddress> {
591591
typename = typename detail::enable_if_t<
592592
RelaySpace == Space && Space == access::address_space::local_space>>
593593
multi_ptr(local_accessor<ElementType, Dimensions> Accessor)
594-
: multi_ptr(Accessor.get_pointer().get()) {}
594+
: m_Pointer(detail::cast_AS<decorated_type *>(Accessor.get_pointer())) {}
595595

596596
// Assignment operators
597597
multi_ptr &operator=(const multi_ptr &) = default;
@@ -848,7 +848,7 @@ class multi_ptr<ElementType, Space, access::decorated::legacy> {
848848
std::is_const<ET>::value && std::is_same<ET, ElementType>::value>>
849849
multi_ptr(
850850
local_accessor<typename detail::remove_const_t<ET>, dimensions> Accessor)
851-
: multi_ptr(Accessor.get_pointer()) {}
851+
: m_Pointer(detail::cast_AS<pointer_t>(Accessor.get_pointer())) {}
852852

853853
// Only if Space == constant_space and element type is const
854854
template <
@@ -1089,7 +1089,7 @@ class multi_ptr<void, Space, access::decorated::legacy> {
10891089
_Space == Space && (Space == access::address_space::generic_space ||
10901090
Space == access::address_space::local_space)>>
10911091
multi_ptr(local_accessor<ElementType, dimensions> Accessor)
1092-
: multi_ptr(Accessor.get_pointer()) {}
1092+
: m_Pointer(detail::cast_AS<pointer_t>(Accessor.get_pointer())) {}
10931093

10941094
// Only if Space == constant_space
10951095
template <
@@ -1232,7 +1232,7 @@ class multi_ptr<const void, Space, access::decorated::legacy> {
12321232
_Space == Space && (Space == access::address_space::generic_space ||
12331233
Space == access::address_space::local_space)>>
12341234
multi_ptr(local_accessor<ElementType, dimensions> Accessor)
1235-
: multi_ptr(Accessor.get_pointer()) {}
1235+
: m_Pointer(detail::cast_AS<pointer_t>(Accessor.get_pointer())) {}
12361236

12371237
// Only if Space == constant_space
12381238
template <
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -fsyntax-only
2+
3+
#include <cassert>
4+
#include <sycl/sycl.hpp>
5+
#include <type_traits>
6+
7+
using namespace sycl;
8+
9+
constexpr static int size = 1;
10+
11+
void test_get_multi_ptr(handler &cgh, buffer<int, size> &buffer) {
12+
using target_local_accessor_t =
13+
accessor<int, size, access::mode::read_write, access::target::local>;
14+
using local_accessor_t = local_accessor<int, size>;
15+
16+
auto acc = buffer.get_access<access_mode::read_write, target::host_task>(cgh);
17+
auto target_local_acc = target_local_accessor_t({size}, cgh);
18+
auto local_acc = local_accessor_t({size}, cgh);
19+
20+
auto acc_ptr = acc.get_pointer();
21+
auto target_local_ptr = target_local_acc.get_pointer();
22+
auto local_pointer = local_acc.get_pointer();
23+
static_assert(std::is_same_v<decltype(acc_ptr), std::add_pointer_t<int>>);
24+
static_assert(std::is_same_v<decltype(target_local_ptr), local_ptr<int>>);
25+
static_assert(
26+
std::is_same_v<decltype(local_pointer), std::add_pointer_t<int>>);
27+
}

0 commit comments

Comments
 (0)