Skip to content

Commit 1237051

Browse files
authored
[SYCL] Implement sycl_ext_oneapi_address_cast (#12382)
- Adds support for OpGenericCastToPtr SPIR-V intrinsic. - Implements static_address_cast using OpGenericCastToPtr. - Implements dynamic_address_cast using OpGenericCastToPtrExplicit. - Adds tests for both new forms of address_cast. --------- Signed-off-by: John Pennycook <[email protected]>
1 parent 8e7995d commit 1237051

File tree

7 files changed

+349
-1
lines changed

7 files changed

+349
-1
lines changed

clang/lib/Sema/SPIRVBuiltins.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,9 @@ foreach Ty = [Void, ConstType<Void>, VolatileType<Void>, VolatileType<ConstType<
844844
def : SPVBuiltin<"GenericCastToPtrExplicit_ToGlobal", [PointerType<Ty, GlobalAS>, PointerType<Ty, DefaultAS>, Int], Attr.Const>;
845845
def : SPVBuiltin<"GenericCastToPtrExplicit_ToLocal", [PointerType<Ty, LocalAS>, PointerType<Ty, DefaultAS>, Int], Attr.Const>;
846846
def : SPVBuiltin<"GenericCastToPtrExplicit_ToPrivate", [PointerType<Ty, PrivateAS>, PointerType<Ty, DefaultAS>, Int], Attr.Const>;
847+
def : SPVBuiltin<"GenericCastToPtr_ToGlobal", [PointerType<Ty, GlobalAS>, PointerType<Ty, DefaultAS>, Int], Attr.Const>;
848+
def : SPVBuiltin<"GenericCastToPtr_ToLocal", [PointerType<Ty, LocalAS>, PointerType<Ty, DefaultAS>, Int], Attr.Const>;
849+
def : SPVBuiltin<"GenericCastToPtr_ToPrivate", [PointerType<Ty, PrivateAS>, PointerType<Ty, DefaultAS>, Int], Attr.Const>;
847850
}
848851

849852
foreach Type = TLFloat.List in {

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,94 @@ __SYCL_GenericCastToPtrExplicit_ToPrivate(const volatile void *Ptr) noexcept {
468468
__spv::StorageClass::Function);
469469
}
470470

471+
template <typename dataT>
472+
extern __attribute__((opencl_global)) dataT *
473+
__SYCL_GenericCastToPtr_ToGlobal(void *Ptr) noexcept {
474+
return (__attribute__((opencl_global)) dataT *)
475+
__spirv_GenericCastToPtr_ToGlobal(Ptr,
476+
__spv::StorageClass::CrossWorkgroup);
477+
}
478+
479+
template <typename dataT>
480+
extern const __attribute__((opencl_global)) dataT *
481+
__SYCL_GenericCastToPtr_ToGlobal(const void *Ptr) noexcept {
482+
return (const __attribute__((opencl_global)) dataT *)
483+
__spirv_GenericCastToPtr_ToGlobal(Ptr,
484+
__spv::StorageClass::CrossWorkgroup);
485+
}
486+
487+
template <typename dataT>
488+
extern volatile __attribute__((opencl_global)) dataT *
489+
__SYCL_GenericCastToPtr_ToGlobal(volatile void *Ptr) noexcept {
490+
return (volatile __attribute__((opencl_global)) dataT *)
491+
__spirv_GenericCastToPtr_ToGlobal(Ptr,
492+
__spv::StorageClass::CrossWorkgroup);
493+
}
494+
495+
template <typename dataT>
496+
extern const volatile __attribute__((opencl_global)) dataT *
497+
__SYCL_GenericCastToPtr_ToGlobal(const volatile void *Ptr) noexcept {
498+
return (const volatile __attribute__((opencl_global)) dataT *)
499+
__spirv_GenericCastToPtr_ToGlobal(Ptr,
500+
__spv::StorageClass::CrossWorkgroup);
501+
}
502+
503+
template <typename dataT>
504+
extern __attribute__((opencl_local)) dataT *
505+
__SYCL_GenericCastToPtr_ToLocal(void *Ptr) noexcept {
506+
return (__attribute__((opencl_local)) dataT *)
507+
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
508+
}
509+
510+
template <typename dataT>
511+
extern const __attribute__((opencl_local)) dataT *
512+
__SYCL_GenericCastToPtr_ToLocal(const void *Ptr) noexcept {
513+
return (const __attribute__((opencl_local)) dataT *)
514+
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
515+
}
516+
517+
template <typename dataT>
518+
extern volatile __attribute__((opencl_local)) dataT *
519+
__SYCL_GenericCastToPtr_ToLocal(volatile void *Ptr) noexcept {
520+
return (volatile __attribute__((opencl_local)) dataT *)
521+
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
522+
}
523+
524+
template <typename dataT>
525+
extern const volatile __attribute__((opencl_local)) dataT *
526+
__SYCL_GenericCastToPtr_ToLocal(const volatile void *Ptr) noexcept {
527+
return (const volatile __attribute__((opencl_local)) dataT *)
528+
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
529+
}
530+
531+
template <typename dataT>
532+
extern __attribute__((opencl_private)) dataT *
533+
__SYCL_GenericCastToPtr_ToPrivate(void *Ptr) noexcept {
534+
return (__attribute__((opencl_private)) dataT *)
535+
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
536+
}
537+
538+
template <typename dataT>
539+
extern const __attribute__((opencl_private)) dataT *
540+
__SYCL_GenericCastToPtr_ToPrivate(const void *Ptr) noexcept {
541+
return (const __attribute__((opencl_private)) dataT *)
542+
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
543+
}
544+
545+
template <typename dataT>
546+
extern volatile __attribute__((opencl_private)) dataT *
547+
__SYCL_GenericCastToPtr_ToPrivate(volatile void *Ptr) noexcept {
548+
return (volatile __attribute__((opencl_private)) dataT *)
549+
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
550+
}
551+
552+
template <typename dataT>
553+
extern const volatile __attribute__((opencl_private)) dataT *
554+
__SYCL_GenericCastToPtr_ToPrivate(const volatile void *Ptr) noexcept {
555+
return (const volatile __attribute__((opencl_private)) dataT *)
556+
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
557+
}
558+
471559
template <typename dataT>
472560
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL dataT
473561
__spirv_SubgroupShuffleINTEL(dataT Data, uint32_t InvocationId) noexcept;

sycl/include/sycl/detail/spirv.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,6 +1176,30 @@ __SYCL_GROUP_COLLECTIVE_OVERLOAD(BitwiseAndKHR)
11761176
__SYCL_GROUP_COLLECTIVE_OVERLOAD(LogicalAndKHR)
11771177
__SYCL_GROUP_COLLECTIVE_OVERLOAD(LogicalOrKHR)
11781178

1179+
template <access::address_space Space, typename T>
1180+
auto GenericCastToPtr(T *Ptr) ->
1181+
typename multi_ptr<T, Space, access::decorated::yes>::pointer {
1182+
if constexpr (Space == access::address_space::global_space) {
1183+
return __SYCL_GenericCastToPtr_ToGlobal<T>(Ptr);
1184+
} else if constexpr (Space == access::address_space::local_space) {
1185+
return __SYCL_GenericCastToPtr_ToLocal<T>(Ptr);
1186+
} else if constexpr (Space == access::address_space::private_space) {
1187+
return __SYCL_GenericCastToPtr_ToPrivate<T>(Ptr);
1188+
}
1189+
}
1190+
1191+
template <access::address_space Space, typename T>
1192+
auto GenericCastToPtrExplicit(T *Ptr) ->
1193+
typename multi_ptr<T, Space, access::decorated::yes>::pointer {
1194+
if constexpr (Space == access::address_space::global_space) {
1195+
return __SYCL_GenericCastToPtrExplicit_ToGlobal<T>(Ptr);
1196+
} else if constexpr (Space == access::address_space::local_space) {
1197+
return __SYCL_GenericCastToPtrExplicit_ToLocal<T>(Ptr);
1198+
} else if constexpr (Space == access::address_space::private_space) {
1199+
return __SYCL_GenericCastToPtrExplicit_ToPrivate<T>(Ptr);
1200+
}
1201+
}
1202+
11791203
} // namespace spirv
11801204
} // namespace detail
11811205
} // namespace _V1
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//==----------- address_cast.hpp - sycl_ext_oneapi_address_cast ------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
#include <sycl/multi_ptr.hpp>
11+
12+
namespace sycl {
13+
inline namespace _V1 {
14+
namespace ext {
15+
namespace oneapi {
16+
namespace experimental {
17+
18+
template <access::address_space Space, access::decorated DecorateAddress,
19+
typename ElementType>
20+
multi_ptr<ElementType, Space, DecorateAddress>
21+
static_address_cast(ElementType *Ptr) {
22+
#ifdef __SYCL_DEVICE_ONLY__
23+
auto CastPtr = sycl::detail::spirv::GenericCastToPtr<Space>(Ptr);
24+
return multi_ptr<ElementType, Space, DecorateAddress>(CastPtr);
25+
#else
26+
return multi_ptr<ElementType, Space, DecorateAddress>(Ptr);
27+
#endif
28+
}
29+
30+
template <access::address_space Space, access::decorated DecorateAddress,
31+
typename ElementType>
32+
multi_ptr<ElementType, Space, DecorateAddress>
33+
dynamic_address_cast(ElementType *Ptr) {
34+
#ifdef __SYCL_DEVICE_ONLY__
35+
auto CastPtr = sycl::detail::spirv::GenericCastToPtrExplicit<Space>(Ptr);
36+
return multi_ptr<ElementType, Space, DecorateAddress>(CastPtr);
37+
#else
38+
return multi_ptr<ElementType, Space, DecorateAddress>(Ptr);
39+
#endif
40+
}
41+
42+
} // namespace experimental
43+
} // namespace oneapi
44+
} // namespace ext
45+
} // namespace _V1
46+
} // namespace sycl

sycl/include/sycl/sycl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
#include <sycl/ext/oneapi/bindless_images.hpp>
7575
#include <sycl/ext/oneapi/device_global/device_global.hpp>
7676
#include <sycl/ext/oneapi/device_global/properties.hpp>
77+
#include <sycl/ext/oneapi/experimental/address_cast.hpp>
7778
#include <sycl/ext/oneapi/experimental/annotated_arg/annotated_arg.hpp>
7879
#include <sycl/ext/oneapi/experimental/annotated_ptr/annotated_ptr.hpp>
7980
#include <sycl/ext/oneapi/experimental/annotated_usm/alloc_device.hpp>
@@ -100,7 +101,6 @@
100101
#include <sycl/ext/oneapi/sub_group.hpp>
101102
#include <sycl/ext/oneapi/sub_group_mask.hpp>
102103
#include <sycl/ext/oneapi/weak_object.hpp>
103-
104104
#if !defined(SYCL2020_CONFORMANT_APIS) && \
105105
!defined(__INTEL_PREVIEW_BREAKING_CHANGES)
106106
// We used to include those and some code might be reliant on that.
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
//==--------- dynamic_address_cast.cpp - dynamic address_cast test ---------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Issue with OpenCL CPU runtime implementation of OpGenericCastToPtrExplicit
10+
// OpGenericCastToPtr* intrinsics not implemented on AMD or NVIDIA
11+
// UNSUPPORTED: cpu, hip, cuda
12+
// RUN: %{build} -o %t.out
13+
// RUN: %{run} %t.out
14+
#include <sycl/sycl.hpp>
15+
16+
int main() {
17+
18+
sycl::queue Queue;
19+
20+
sycl::range<1> NItems{1};
21+
22+
sycl::buffer<int, 1> GlobalBuffer{NItems};
23+
sycl::buffer<bool, 1> ResultBuffer{NItems};
24+
25+
Queue
26+
.submit([&](sycl::handler &cgh) {
27+
auto GlobalAccessor =
28+
GlobalBuffer.get_access<sycl::access::mode::read_write>(cgh);
29+
auto LocalAccessor = sycl::local_accessor<int>(1, cgh);
30+
auto ResultAccessor =
31+
ResultBuffer.get_access<sycl::access::mode::write>(cgh);
32+
cgh.parallel_for<class Kernel>(
33+
sycl::nd_range<1>(NItems, 1), [=](sycl::nd_item<1> Item) {
34+
bool Success = true;
35+
size_t Index = Item.get_global_id(0);
36+
37+
int *RawGlobalPointer = &GlobalAccessor[Index];
38+
{
39+
auto GlobalPointer =
40+
sycl::ext::oneapi::experimental::dynamic_address_cast<
41+
sycl::access::address_space::global_space,
42+
sycl::access::decorated::no>(RawGlobalPointer);
43+
auto LocalPointer =
44+
sycl::ext::oneapi::experimental::dynamic_address_cast<
45+
sycl::access::address_space::local_space,
46+
sycl::access::decorated::no>(RawGlobalPointer);
47+
auto PrivatePointer =
48+
sycl::ext::oneapi::experimental::dynamic_address_cast<
49+
sycl::access::address_space::private_space,
50+
sycl::access::decorated::no>(RawGlobalPointer);
51+
Success &= reinterpret_cast<size_t>(RawGlobalPointer) ==
52+
reinterpret_cast<size_t>(GlobalPointer.get_raw());
53+
Success &= LocalPointer.get_raw() == nullptr;
54+
Success &= PrivatePointer.get_raw() == nullptr;
55+
}
56+
57+
int *RawLocalPointer = &LocalAccessor[0];
58+
{
59+
auto GlobalPointer =
60+
sycl::ext::oneapi::experimental::dynamic_address_cast<
61+
sycl::access::address_space::global_space,
62+
sycl::access::decorated::no>(RawLocalPointer);
63+
auto LocalPointer =
64+
sycl::ext::oneapi::experimental::dynamic_address_cast<
65+
sycl::access::address_space::local_space,
66+
sycl::access::decorated::no>(RawLocalPointer);
67+
auto PrivatePointer =
68+
sycl::ext::oneapi::experimental::dynamic_address_cast<
69+
sycl::access::address_space::private_space,
70+
sycl::access::decorated::no>(RawLocalPointer);
71+
Success &= GlobalPointer.get_raw() == nullptr;
72+
Success &= reinterpret_cast<size_t>(RawLocalPointer) ==
73+
reinterpret_cast<size_t>(LocalPointer.get_raw());
74+
Success &= PrivatePointer.get_raw() == nullptr;
75+
}
76+
77+
int PrivateVariable = 0;
78+
int *RawPrivatePointer = &PrivateVariable;
79+
{
80+
auto GlobalPointer =
81+
sycl::ext::oneapi::experimental::dynamic_address_cast<
82+
sycl::access::address_space::global_space,
83+
sycl::access::decorated::no>(RawPrivatePointer);
84+
auto LocalPointer =
85+
sycl::ext::oneapi::experimental::dynamic_address_cast<
86+
sycl::access::address_space::local_space,
87+
sycl::access::decorated::no>(RawPrivatePointer);
88+
auto PrivatePointer =
89+
sycl::ext::oneapi::experimental::dynamic_address_cast<
90+
sycl::access::address_space::private_space,
91+
sycl::access::decorated::no>(RawPrivatePointer);
92+
Success &= GlobalPointer.get_raw() == nullptr;
93+
Success &= LocalPointer.get_raw() == nullptr;
94+
Success &= reinterpret_cast<size_t>(RawPrivatePointer) ==
95+
reinterpret_cast<size_t>(PrivatePointer.get_raw());
96+
}
97+
98+
ResultAccessor[Index] = Success;
99+
});
100+
})
101+
.wait();
102+
103+
bool Success = true;
104+
{
105+
auto ResultAccessor = sycl::host_accessor(ResultBuffer);
106+
for (int i = 0; i < NItems.size(); ++i) {
107+
Success &= ResultAccessor[i];
108+
};
109+
}
110+
111+
return (Success) ? 0 : -1;
112+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
//==---------- static_address_cast.cpp - static address_cast test ----------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// OpGenericCastToPtr* intrinsics not implemented on AMD or NVIDIA
10+
// UNSUPPORTED: hip, cuda
11+
// RUN: %{build} -o %t.out
12+
// RUN: %{run} %t.out
13+
#include <sycl/sycl.hpp>
14+
15+
int main() {
16+
17+
sycl::queue Queue;
18+
19+
sycl::range<1> NItems{1};
20+
21+
sycl::buffer<int, 1> GlobalBuffer{NItems};
22+
sycl::buffer<bool, 1> ResultBuffer{NItems};
23+
24+
Queue
25+
.submit([&](sycl::handler &cgh) {
26+
auto GlobalAccessor =
27+
GlobalBuffer.get_access<sycl::access::mode::read_write>(cgh);
28+
auto LocalAccessor = sycl::local_accessor<int>(1, cgh);
29+
auto ResultAccessor =
30+
ResultBuffer.get_access<sycl::access::mode::write>(cgh);
31+
cgh.parallel_for<class Kernel>(
32+
sycl::nd_range<1>(NItems, 1), [=](sycl::nd_item<1> Item) {
33+
bool Success = true;
34+
size_t Index = Item.get_global_id(0);
35+
36+
int *RawGlobalPointer = &GlobalAccessor[Index];
37+
auto GlobalPointer =
38+
sycl::ext::oneapi::experimental::static_address_cast<
39+
sycl::access::address_space::global_space,
40+
sycl::access::decorated::no>(RawGlobalPointer);
41+
Success &= reinterpret_cast<size_t>(RawGlobalPointer) ==
42+
reinterpret_cast<size_t>(GlobalPointer.get_raw());
43+
44+
int *RawLocalPointer = &LocalAccessor[0];
45+
auto LocalPointer =
46+
sycl::ext::oneapi::experimental::static_address_cast<
47+
sycl::access::address_space::local_space,
48+
sycl::access::decorated::no>(RawLocalPointer);
49+
Success &= reinterpret_cast<size_t>(RawLocalPointer) ==
50+
reinterpret_cast<size_t>(LocalPointer.get_raw());
51+
52+
int PrivateVariable = 0;
53+
int *RawPrivatePointer = &PrivateVariable;
54+
auto PrivatePointer =
55+
sycl::ext::oneapi::experimental::static_address_cast<
56+
sycl::access::address_space::private_space,
57+
sycl::access::decorated::no>(RawPrivatePointer);
58+
Success &= reinterpret_cast<size_t>(RawPrivatePointer) ==
59+
reinterpret_cast<size_t>(PrivatePointer.get_raw());
60+
61+
ResultAccessor[Index] = Success;
62+
});
63+
})
64+
.wait();
65+
66+
bool Success = true;
67+
{
68+
auto ResultAccessor = sycl::host_accessor(ResultBuffer);
69+
for (int i = 0; i < NItems.size(); ++i) {
70+
Success &= ResultAccessor[i];
71+
};
72+
}
73+
74+
return (Success) ? 0 : -1;
75+
}

0 commit comments

Comments
 (0)