Skip to content

Commit 3c8cf0b

Browse files
[SYCL] Use decorated pointer in device_global (#7796)
device_global without device_image_scope currently uses an undecorated pointer for its underlying type, but we know that the pointer will be to the global memory space. This commit changes the underlying member for device_global without device_image_scope to be a decorated pointer and changes the get_multi_ptr member function to create the multi_ptr directly from the decorated pointer instead of performing unnecessary address space casts. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent d47b99e commit 3c8cf0b

File tree

1 file changed

+39
-19
lines changed

1 file changed

+39
-19
lines changed

sycl/include/sycl/ext/oneapi/device_global/device_global.hpp

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <sycl/exception.hpp>
1616
#include <sycl/ext/oneapi/device_global/properties.hpp>
1717
#include <sycl/ext/oneapi/properties/properties.hpp>
18+
#include <sycl/pointers.hpp>
1819

1920
#ifdef __SYCL_DEVICE_ONLY__
2021
#define __SYCL_HOST_NOT_SUPPORTED(Op)
@@ -42,9 +43,27 @@ struct HasArrowOperator<
4243
template <typename T, typename PropertyListT, typename = void>
4344
class device_global_base {
4445
protected:
45-
T *usmptr;
46-
T *get_ptr() noexcept { return usmptr; }
47-
const T *get_ptr() const noexcept { return usmptr; }
46+
using pointer_t = typename decorated_global_ptr<T>::pointer;
47+
pointer_t usmptr;
48+
pointer_t get_ptr() noexcept { return usmptr; }
49+
const pointer_t get_ptr() const noexcept { return usmptr; }
50+
51+
public:
52+
template <access::decorated IsDecorated>
53+
multi_ptr<T, access::address_space::global_space, IsDecorated>
54+
get_multi_ptr() noexcept {
55+
__SYCL_HOST_NOT_SUPPORTED("get_multi_ptr()")
56+
return multi_ptr<T, access::address_space::global_space, IsDecorated>{
57+
get_ptr()};
58+
}
59+
60+
template <access::decorated IsDecorated>
61+
multi_ptr<const T, access::address_space::global_space, IsDecorated>
62+
get_multi_ptr() const noexcept {
63+
__SYCL_HOST_NOT_SUPPORTED("get_multi_ptr()")
64+
return multi_ptr<const T, access::address_space::global_space, IsDecorated>{
65+
get_ptr()};
66+
}
4867
};
4968

5069
// Specialization of device_global base class for when device_image_scope is in
@@ -58,6 +77,23 @@ class device_global_base<
5877
T val{};
5978
T *get_ptr() noexcept { return &val; }
6079
const T *get_ptr() const noexcept { return &val; }
80+
81+
public:
82+
template <access::decorated IsDecorated>
83+
multi_ptr<T, access::address_space::global_space, IsDecorated>
84+
get_multi_ptr() noexcept {
85+
__SYCL_HOST_NOT_SUPPORTED("get_multi_ptr()")
86+
return address_space_cast<access::address_space::global_space, IsDecorated,
87+
T>(this->get_ptr());
88+
}
89+
90+
template <access::decorated IsDecorated>
91+
multi_ptr<const T, access::address_space::global_space, IsDecorated>
92+
get_multi_ptr() const noexcept {
93+
__SYCL_HOST_NOT_SUPPORTED("get_multi_ptr()")
94+
return address_space_cast<access::address_space::global_space, IsDecorated,
95+
const T>(this->get_ptr());
96+
}
6197
};
6298
} // namespace detail
6399

@@ -113,22 +149,6 @@ class
113149
device_global &operator=(const device_global &) = delete;
114150
device_global &operator=(const device_global &&) = delete;
115151

116-
template <access::decorated IsDecorated>
117-
multi_ptr<T, access::address_space::global_space, IsDecorated>
118-
get_multi_ptr() noexcept {
119-
__SYCL_HOST_NOT_SUPPORTED("get_multi_ptr()")
120-
return address_space_cast<access::address_space::global_space, IsDecorated>(
121-
this->get_ptr());
122-
}
123-
124-
template <access::decorated IsDecorated>
125-
multi_ptr<const T, access::address_space::global_space, IsDecorated>
126-
get_multi_ptr() const noexcept {
127-
__SYCL_HOST_NOT_SUPPORTED("get_multi_ptr()")
128-
return address_space_cast<access::address_space::global_space, IsDecorated,
129-
const T>(this->get_ptr());
130-
}
131-
132152
T &get() noexcept {
133153
__SYCL_HOST_NOT_SUPPORTED("get()")
134154
return *this->get_ptr();

0 commit comments

Comments
 (0)