Skip to content

[SYCL] Use decorated pointer in device_global #7796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 39 additions & 19 deletions sycl/include/sycl/ext/oneapi/device_global/device_global.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <sycl/exception.hpp>
#include <sycl/ext/oneapi/device_global/properties.hpp>
#include <sycl/ext/oneapi/properties/properties.hpp>
#include <sycl/pointers.hpp>

#ifdef __SYCL_DEVICE_ONLY__
#define __SYCL_HOST_NOT_SUPPORTED(Op)
Expand Down Expand Up @@ -42,9 +43,27 @@ struct HasArrowOperator<
template <typename T, typename PropertyListT, typename = void>
class device_global_base {
protected:
T *usmptr;
T *get_ptr() noexcept { return usmptr; }
const T *get_ptr() const noexcept { return usmptr; }
using pointer_t = typename decorated_global_ptr<T>::pointer;
pointer_t usmptr;
pointer_t get_ptr() noexcept { return usmptr; }
const pointer_t get_ptr() const noexcept { return usmptr; }

public:
template <access::decorated IsDecorated>
multi_ptr<T, access::address_space::global_space, IsDecorated>
get_multi_ptr() noexcept {
__SYCL_HOST_NOT_SUPPORTED("get_multi_ptr()")
return multi_ptr<T, access::address_space::global_space, IsDecorated>{
get_ptr()};
}

template <access::decorated IsDecorated>
multi_ptr<const T, access::address_space::global_space, IsDecorated>
get_multi_ptr() const noexcept {
__SYCL_HOST_NOT_SUPPORTED("get_multi_ptr()")
return multi_ptr<const T, access::address_space::global_space, IsDecorated>{
get_ptr()};
}
};

// Specialization of device_global base class for when device_image_scope is in
Expand All @@ -58,6 +77,23 @@ class device_global_base<
T val{};
T *get_ptr() noexcept { return &val; }
const T *get_ptr() const noexcept { return &val; }

public:
template <access::decorated IsDecorated>
multi_ptr<T, access::address_space::global_space, IsDecorated>
get_multi_ptr() noexcept {
__SYCL_HOST_NOT_SUPPORTED("get_multi_ptr()")
return address_space_cast<access::address_space::global_space, IsDecorated,
T>(this->get_ptr());
}

template <access::decorated IsDecorated>
multi_ptr<const T, access::address_space::global_space, IsDecorated>
get_multi_ptr() const noexcept {
__SYCL_HOST_NOT_SUPPORTED("get_multi_ptr()")
return address_space_cast<access::address_space::global_space, IsDecorated,
const T>(this->get_ptr());
}
};
} // namespace detail

Expand Down Expand Up @@ -113,22 +149,6 @@ class
device_global &operator=(const device_global &) = delete;
device_global &operator=(const device_global &&) = delete;

template <access::decorated IsDecorated>
multi_ptr<T, access::address_space::global_space, IsDecorated>
get_multi_ptr() noexcept {
__SYCL_HOST_NOT_SUPPORTED("get_multi_ptr()")
return address_space_cast<access::address_space::global_space, IsDecorated>(
this->get_ptr());
}

template <access::decorated IsDecorated>
multi_ptr<const T, access::address_space::global_space, IsDecorated>
get_multi_ptr() const noexcept {
__SYCL_HOST_NOT_SUPPORTED("get_multi_ptr()")
return address_space_cast<access::address_space::global_space, IsDecorated,
const T>(this->get_ptr());
}

T &get() noexcept {
__SYCL_HOST_NOT_SUPPORTED("get()")
return *this->get_ptr();
Expand Down