Skip to content

[SYCL] Refactor address space casts functionality #15543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 14, 2024
184 changes: 0 additions & 184 deletions sycl/include/sycl/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,190 +540,6 @@ __SPIRV_ATOMICS(__SPIRV_ATOMIC_MINMAX, Max)
#undef __SPIRV_ATOMIC_UNSIGNED
#undef __SPIRV_ATOMIC_XOR

template <typename dataT>
extern __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtrExplicit_ToGlobal(void *Ptr) noexcept {
return (__attribute__((opencl_global)) dataT *)
__spirv_GenericCastToPtrExplicit_ToGlobal(
Ptr, __spv::StorageClass::CrossWorkgroup);
}

template <typename dataT>
extern const __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtrExplicit_ToGlobal(const void *Ptr) noexcept {
return (const __attribute__((opencl_global)) dataT *)
__spirv_GenericCastToPtrExplicit_ToGlobal(
Ptr, __spv::StorageClass::CrossWorkgroup);
}

template <typename dataT>
extern volatile __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtrExplicit_ToGlobal(volatile void *Ptr) noexcept {
return (volatile __attribute__((opencl_global)) dataT *)
__spirv_GenericCastToPtrExplicit_ToGlobal(
Ptr, __spv::StorageClass::CrossWorkgroup);
}

template <typename dataT>
extern const volatile __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtrExplicit_ToGlobal(const volatile void *Ptr) noexcept {
return (const volatile __attribute__((opencl_global)) dataT *)
__spirv_GenericCastToPtrExplicit_ToGlobal(
Ptr, __spv::StorageClass::CrossWorkgroup);
}

template <typename dataT>
extern __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtrExplicit_ToLocal(void *Ptr) noexcept {
return (__attribute__((opencl_local)) dataT *)
__spirv_GenericCastToPtrExplicit_ToLocal(Ptr,
__spv::StorageClass::Workgroup);
}

template <typename dataT>
extern const __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtrExplicit_ToLocal(const void *Ptr) noexcept {
return (const __attribute__((opencl_local)) dataT *)
__spirv_GenericCastToPtrExplicit_ToLocal(Ptr,
__spv::StorageClass::Workgroup);
}

template <typename dataT>
extern volatile __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtrExplicit_ToLocal(volatile void *Ptr) noexcept {
return (volatile __attribute__((opencl_local)) dataT *)
__spirv_GenericCastToPtrExplicit_ToLocal(Ptr,
__spv::StorageClass::Workgroup);
}

template <typename dataT>
extern const volatile __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtrExplicit_ToLocal(const volatile void *Ptr) noexcept {
return (const volatile __attribute__((opencl_local)) dataT *)
__spirv_GenericCastToPtrExplicit_ToLocal(Ptr,
__spv::StorageClass::Workgroup);
}

template <typename dataT>
extern __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtrExplicit_ToPrivate(void *Ptr) noexcept {
return (__attribute__((opencl_private)) dataT *)
__spirv_GenericCastToPtrExplicit_ToPrivate(Ptr,
__spv::StorageClass::Function);
}

template <typename dataT>
extern const __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtrExplicit_ToPrivate(const void *Ptr) noexcept {
return (const __attribute__((opencl_private)) dataT *)
__spirv_GenericCastToPtrExplicit_ToPrivate(Ptr,
__spv::StorageClass::Function);
}

template <typename dataT>
extern volatile __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtrExplicit_ToPrivate(volatile void *Ptr) noexcept {
return (volatile __attribute__((opencl_private)) dataT *)
__spirv_GenericCastToPtrExplicit_ToPrivate(Ptr,
__spv::StorageClass::Function);
}

template <typename dataT>
extern const volatile __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtrExplicit_ToPrivate(const volatile void *Ptr) noexcept {
return (const volatile __attribute__((opencl_private)) dataT *)
__spirv_GenericCastToPtrExplicit_ToPrivate(Ptr,
__spv::StorageClass::Function);
}

template <typename dataT>
extern __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtr_ToGlobal(void *Ptr) noexcept {
return (__attribute__((opencl_global)) dataT *)
__spirv_GenericCastToPtr_ToGlobal(Ptr,
__spv::StorageClass::CrossWorkgroup);
}

template <typename dataT>
extern const __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtr_ToGlobal(const void *Ptr) noexcept {
return (const __attribute__((opencl_global)) dataT *)
__spirv_GenericCastToPtr_ToGlobal(Ptr,
__spv::StorageClass::CrossWorkgroup);
}

template <typename dataT>
extern volatile __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtr_ToGlobal(volatile void *Ptr) noexcept {
return (volatile __attribute__((opencl_global)) dataT *)
__spirv_GenericCastToPtr_ToGlobal(Ptr,
__spv::StorageClass::CrossWorkgroup);
}

template <typename dataT>
extern const volatile __attribute__((opencl_global)) dataT *
__SYCL_GenericCastToPtr_ToGlobal(const volatile void *Ptr) noexcept {
return (const volatile __attribute__((opencl_global)) dataT *)
__spirv_GenericCastToPtr_ToGlobal(Ptr,
__spv::StorageClass::CrossWorkgroup);
}

template <typename dataT>
extern __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtr_ToLocal(void *Ptr) noexcept {
return (__attribute__((opencl_local)) dataT *)
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
}

template <typename dataT>
extern const __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtr_ToLocal(const void *Ptr) noexcept {
return (const __attribute__((opencl_local)) dataT *)
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
}

template <typename dataT>
extern volatile __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtr_ToLocal(volatile void *Ptr) noexcept {
return (volatile __attribute__((opencl_local)) dataT *)
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
}

template <typename dataT>
extern const volatile __attribute__((opencl_local)) dataT *
__SYCL_GenericCastToPtr_ToLocal(const volatile void *Ptr) noexcept {
return (const volatile __attribute__((opencl_local)) dataT *)
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
}

template <typename dataT>
extern __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtr_ToPrivate(void *Ptr) noexcept {
return (__attribute__((opencl_private)) dataT *)
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
}

template <typename dataT>
extern const __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtr_ToPrivate(const void *Ptr) noexcept {
return (const __attribute__((opencl_private)) dataT *)
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
}

template <typename dataT>
extern volatile __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtr_ToPrivate(volatile void *Ptr) noexcept {
return (volatile __attribute__((opencl_private)) dataT *)
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
}

template <typename dataT>
extern const volatile __attribute__((opencl_private)) dataT *
__SYCL_GenericCastToPtr_ToPrivate(const volatile void *Ptr) noexcept {
return (const volatile __attribute__((opencl_private)) dataT *)
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
}

template <typename dataT>
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL dataT
__spirv_SubgroupShuffleINTEL(dataT Data, uint32_t InvocationId) noexcept;
Expand Down
190 changes: 143 additions & 47 deletions sycl/include/sycl/access/access.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,58 +325,154 @@ template <typename T>
using remove_decoration_t = typename remove_decoration<T>::type;

namespace detail {

// Helper function for selecting appropriate casts between address spaces.
template <typename ToT, typename FromT> inline ToT cast_AS(FromT from) {
#ifdef __SYCL_DEVICE_ONLY__
constexpr access::address_space ToAS = deduce_AS<ToT>::value;
constexpr access::address_space FromAS = deduce_AS<FromT>::value;
if constexpr (FromAS == access::address_space::generic_space) {
#if defined(__NVPTX__) || defined(__AMDGCN__) || defined(__SYCL_NATIVE_CPU__)
// TODO: NVPTX and AMDGCN backends do not currently support the
// __spirv_GenericCastToPtrExplicit_* builtins, so to work around this
// we do C-style casting. This may produce warnings when targetting
// these backends.
return (ToT)from;
inline constexpr bool
address_space_cast_is_possible(access::address_space Src,
access::address_space Dst) {
// constant_space is unique and is not interchangeable with any other.
auto constant_space = access::address_space::constant_space;
if (Src == constant_space || Dst == constant_space)
return Src == Dst;

auto generic_space = access::address_space::generic_space;
if (Src == Dst || Src == generic_space || Dst == generic_space)
return true;

// global_host/global_device could be casted to/from global
auto global_space = access::address_space::global_space;
auto global_device = access::address_space::ext_intel_global_device_space;
auto global_host = access::address_space::ext_intel_global_host_space;

if (Src == global_space || Dst == global_space) {
auto Other = Src == global_space ? Dst : Src;
if (Other == global_device || Other == global_host)
return true;
}

// No more compatible combinations.
return false;
}

template <access::address_space Space, typename ElementType>
auto static_address_cast(ElementType *Ptr) {
constexpr auto generic_space = access::address_space::generic_space;
constexpr auto global_space = access::address_space::global_space;
constexpr auto local_space = access::address_space::local_space;
constexpr auto private_space = access::address_space::private_space;
constexpr auto global_device =
access::address_space::ext_intel_global_device_space;
constexpr auto global_host =
access::address_space::ext_intel_global_host_space;

constexpr auto SrcAS = deduce_AS<ElementType *>::value;
static_assert(address_space_cast_is_possible(SrcAS, Space));

using dst_type = typename DecoratedType<
std::remove_pointer_t<remove_decoration_t<ElementType *>>, Space>::type *;

// Note: reinterpret_cast isn't enough for some of the casts between different
// address spaces, use C-style cast instead.
#if !defined(__SPIR__)
return (dst_type)Ptr;
#else
using ToElemT = std::remove_pointer_t<remove_decoration_t<ToT>>;
if constexpr (ToAS == access::address_space::global_space)
return __SYCL_GenericCastToPtrExplicit_ToGlobal<ToElemT>(from);
else if constexpr (ToAS == access::address_space::local_space)
return __SYCL_GenericCastToPtrExplicit_ToLocal<ToElemT>(from);
else if constexpr (ToAS == access::address_space::private_space)
return __SYCL_GenericCastToPtrExplicit_ToPrivate<ToElemT>(from);
#ifdef __ENABLE_USM_ADDR_SPACE__
else if constexpr (ToAS == access::address_space::
ext_intel_global_device_space ||
ToAS ==
access::address_space::ext_intel_global_host_space)
// For extended address spaces we do not currently have a SPIR-V
// conversion function, so we do a C-style cast. This may produce
// warnings.
return (ToT)from;
#endif // __ENABLE_USM_ADDR_SPACE__
else
return reinterpret_cast<ToT>(from);
#endif // defined(__NVPTX__) || defined(__AMDGCN__)
} else
#ifdef __ENABLE_USM_ADDR_SPACE__
if constexpr (FromAS == access::address_space::global_space &&
(ToAS ==
access::address_space::ext_intel_global_device_space ||
ToAS ==
access::address_space::ext_intel_global_host_space)) {
// Casting from global address space to the global device and host address
// spaces is allowed.
return (ToT)from;
} else
#endif // __ENABLE_USM_ADDR_SPACE__
#endif // __SYCL_DEVICE_ONLY__
{
return reinterpret_cast<ToT>(from);
if constexpr (SrcAS != generic_space) {
return (dst_type)Ptr;
} else if constexpr (Space == global_space) {
return (dst_type)__spirv_GenericCastToPtr_ToGlobal(
Ptr, __spv::StorageClass::CrossWorkgroup);
} else if constexpr (Space == local_space) {
return (dst_type)__spirv_GenericCastToPtr_ToLocal(
Ptr, __spv::StorageClass::Workgroup);
} else if constexpr (Space == private_space) {
return (dst_type)__spirv_GenericCastToPtr_ToPrivate(
Ptr, __spv::StorageClass::Function);
#if !defined(__ENABLE_USM_ADDR_SPACE__)
} else if constexpr (Space == global_device || Space == global_host) {
// If __ENABLE_USM_ADDR_SPACE__ isn't defined then both
// global_device/global_host are just aliases for global_space.
return (dst_type)__spirv_GenericCastToPtr_ToGlobal(
Ptr, __spv::StorageClass::CrossWorkgroup);
#endif
} else {
return (dst_type)Ptr;
}
#endif
}

// Previous implementation (`castAS`, used in `multi_ptr` ctors among other
// places), used C-style cast instead of a proper dynamic check for some
// backends/spaces. `SupressNotImplementedAssert = true` parameter is emulating
// that previous behavior until the proper support is added for compatibility
// reasons.
template <access::address_space Space, bool SupressNotImplementedAssert = false,
typename ElementType>
auto dynamic_address_cast(ElementType *Ptr) {
constexpr auto generic_space = access::address_space::generic_space;
constexpr auto global_space = access::address_space::global_space;
constexpr auto local_space = access::address_space::local_space;
constexpr auto private_space = access::address_space::private_space;
constexpr auto global_device =
access::address_space::ext_intel_global_device_space;
constexpr auto global_host =
access::address_space::ext_intel_global_host_space;

constexpr auto SrcAS = deduce_AS<ElementType *>::value;
using dst_type = typename DecoratedType<
std::remove_pointer_t<remove_decoration_t<ElementType *>>, Space>::type *;

if constexpr (!address_space_cast_is_possible(SrcAS, Space)) {
return (dst_type) nullptr;
} else if constexpr (Space == generic_space) {
return (dst_type)Ptr;
} else if constexpr (Space == global_space &&
(SrcAS == global_device || SrcAS == global_host)) {
return (dst_type)Ptr;
} else if constexpr (SrcAS == global_space &&
(Space == global_device || Space == global_host)) {
#if defined(__ENABLE_USM_ADDR_SPACE__)
static_assert(SupressNotImplementedAssert || Space != Space,
"Not supported yet!");
return static_address_cast<Space>(Ptr);
#else
// If __ENABLE_USM_ADDR_SPACE__ isn't defined then both
// global_device/global_host are just aliases for global_space.
static_assert(std::is_same_v<dst_type, ElementType *>);
return (dst_type)Ptr;
#endif
#if defined(__SPIR__)
} else if constexpr (Space == global_space) {
return (dst_type)__spirv_GenericCastToPtrExplicit_ToGlobal(
Ptr, __spv::StorageClass::CrossWorkgroup);
} else if constexpr (Space == local_space) {
return (dst_type)__spirv_GenericCastToPtrExplicit_ToLocal(
Ptr, __spv::StorageClass::Workgroup);
} else if constexpr (Space == private_space) {
return (dst_type)__spirv_GenericCastToPtrExplicit_ToPrivate(
Ptr, __spv::StorageClass::Function);
#if !defined(__ENABLE_USM_ADDR_SPACE__)
} else if constexpr (SrcAS == generic_space &&
(Space == global_device || Space == global_host)) {
return (dst_type)__spirv_GenericCastToPtrExplicit_ToGlobal(
Ptr, __spv::StorageClass::CrossWorkgroup);
#endif
#endif
} else {
static_assert(SupressNotImplementedAssert || Space != Space,
"Not supported yet!");
return static_address_cast<Space>(Ptr);
}
}
#else // __SYCL_DEVICE_ONLY__
template <access::address_space Space, typename ElementType>
auto static_address_cast(ElementType *Ptr) {
return Ptr;
}
template <access::address_space Space, bool SupressNotImplementedAssert = false,
typename ElementType>
auto dynamic_address_cast(ElementType *Ptr) {
return Ptr;
}
#endif // __SYCL_DEVICE_ONLY__
} // namespace detail

#undef __OPENCL_GLOBAL_AS__
Expand Down
Loading
Loading