Skip to content

Commit a541742

Browse files
committed
Add implementation of sycl_ext_oneapi_address_cast
- Adds support for the OpGenericCastToPtr SPIR-V intrinsic.
- Implements static_address_cast using OpGenericCastToPtr.
- Implements dynamic_address_cast using OpGenericCastToPtrExplicit.
- Adds tests for both new forms of address_cast.

Signed-off-by: John Pennycook <[email protected]>
1 parent d3a5f1d commit a541742

File tree

7 files changed

+346
-1
lines changed

7 files changed

+346
-1
lines changed

clang/lib/Sema/SPIRVBuiltins.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,9 @@ foreach Ty = [Void, ConstType<Void>, VolatileType<Void>, VolatileType<ConstType<
844844
def : SPVBuiltin<"GenericCastToPtrExplicit_ToGlobal", [PointerType<Ty, GlobalAS>, PointerType<Ty, DefaultAS>, Int], Attr.Const>;
845845
def : SPVBuiltin<"GenericCastToPtrExplicit_ToLocal", [PointerType<Ty, LocalAS>, PointerType<Ty, DefaultAS>, Int], Attr.Const>;
846846
def : SPVBuiltin<"GenericCastToPtrExplicit_ToPrivate", [PointerType<Ty, PrivateAS>, PointerType<Ty, DefaultAS>, Int], Attr.Const>;
847+
def : SPVBuiltin<"GenericCastToPtr_ToGlobal", [PointerType<Ty, GlobalAS>, PointerType<Ty, DefaultAS>, Int], Attr.Const>;
848+
def : SPVBuiltin<"GenericCastToPtr_ToLocal", [PointerType<Ty, LocalAS>, PointerType<Ty, DefaultAS>, Int], Attr.Const>;
849+
def : SPVBuiltin<"GenericCastToPtr_ToPrivate", [PointerType<Ty, PrivateAS>, PointerType<Ty, DefaultAS>, Int], Attr.Const>;
847850
}
848851

849852
foreach Type = TLFloat.List in {

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,94 @@ __SYCL_GenericCastToPtrExplicit_ToPrivate(const volatile void *Ptr) noexcept {
467467
__spv::StorageClass::Function);
468468
}
469469

470+
template <typename dataT>
471+
extern __attribute__((opencl_global)) dataT *
472+
__SYCL_GenericCastToPtr_ToGlobal(void *Ptr) noexcept {
473+
return (__attribute__((opencl_global)) dataT *)
474+
__spirv_GenericCastToPtr_ToGlobal(Ptr,
475+
__spv::StorageClass::CrossWorkgroup);
476+
}
477+
478+
template <typename dataT>
479+
extern const __attribute__((opencl_global)) dataT *
480+
__SYCL_GenericCastToPtr_ToGlobal(const void *Ptr) noexcept {
481+
return (const __attribute__((opencl_global)) dataT *)
482+
__spirv_GenericCastToPtr_ToGlobal(Ptr,
483+
__spv::StorageClass::CrossWorkgroup);
484+
}
485+
486+
template <typename dataT>
487+
extern volatile __attribute__((opencl_global)) dataT *
488+
__SYCL_GenericCastToPtr_ToGlobal(volatile void *Ptr) noexcept {
489+
return (volatile __attribute__((opencl_global)) dataT *)
490+
__spirv_GenericCastToPtr_ToGlobal(Ptr,
491+
__spv::StorageClass::CrossWorkgroup);
492+
}
493+
494+
template <typename dataT>
495+
extern const volatile __attribute__((opencl_global)) dataT *
496+
__SYCL_GenericCastToPtr_ToGlobal(const volatile void *Ptr) noexcept {
497+
return (const volatile __attribute__((opencl_global)) dataT *)
498+
__spirv_GenericCastToPtr_ToGlobal(Ptr,
499+
__spv::StorageClass::CrossWorkgroup);
500+
}
501+
502+
template <typename dataT>
503+
extern __attribute__((opencl_local)) dataT *
504+
__SYCL_GenericCastToPtr_ToLocal(void *Ptr) noexcept {
505+
return (__attribute__((opencl_local)) dataT *)
506+
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
507+
}
508+
509+
template <typename dataT>
510+
extern const __attribute__((opencl_local)) dataT *
511+
__SYCL_GenericCastToPtr_ToLocal(const void *Ptr) noexcept {
512+
return (const __attribute__((opencl_local)) dataT *)
513+
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
514+
}
515+
516+
template <typename dataT>
517+
extern volatile __attribute__((opencl_local)) dataT *
518+
__SYCL_GenericCastToPtr_ToLocal(volatile void *Ptr) noexcept {
519+
return (volatile __attribute__((opencl_local)) dataT *)
520+
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
521+
}
522+
523+
template <typename dataT>
524+
extern const volatile __attribute__((opencl_local)) dataT *
525+
__SYCL_GenericCastToPtr_ToLocal(const volatile void *Ptr) noexcept {
526+
return (const volatile __attribute__((opencl_local)) dataT *)
527+
__spirv_GenericCastToPtr_ToLocal(Ptr, __spv::StorageClass::Workgroup);
528+
}
529+
530+
template <typename dataT>
531+
extern __attribute__((opencl_private)) dataT *
532+
__SYCL_GenericCastToPtr_ToPrivate(void *Ptr) noexcept {
533+
return (__attribute__((opencl_private)) dataT *)
534+
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
535+
}
536+
537+
template <typename dataT>
538+
extern const __attribute__((opencl_private)) dataT *
539+
__SYCL_GenericCastToPtr_ToPrivate(const void *Ptr) noexcept {
540+
return (const __attribute__((opencl_private)) dataT *)
541+
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
542+
}
543+
544+
template <typename dataT>
545+
extern volatile __attribute__((opencl_private)) dataT *
546+
__SYCL_GenericCastToPtr_ToPrivate(volatile void *Ptr) noexcept {
547+
return (volatile __attribute__((opencl_private)) dataT *)
548+
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
549+
}
550+
551+
template <typename dataT>
552+
extern const volatile __attribute__((opencl_private)) dataT *
553+
__SYCL_GenericCastToPtr_ToPrivate(const volatile void *Ptr) noexcept {
554+
return (const volatile __attribute__((opencl_private)) dataT *)
555+
__spirv_GenericCastToPtr_ToPrivate(Ptr, __spv::StorageClass::Function);
556+
}
557+
470558
template <typename dataT>
471559
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL dataT
472560
__spirv_SubgroupShuffleINTEL(dataT Data, uint32_t InvocationId) noexcept;

sycl/include/sycl/detail/spirv.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,6 +1176,30 @@ __SYCL_GROUP_COLLECTIVE_OVERLOAD(BitwiseAndKHR)
11761176
__SYCL_GROUP_COLLECTIVE_OVERLOAD(LogicalAndKHR)
11771177
__SYCL_GROUP_COLLECTIVE_OVERLOAD(LogicalOrKHR)
11781178

1179+
template <access::address_space Space, typename T>
1180+
auto GenericCastToPtr(T *Ptr) ->
1181+
typename multi_ptr<T, Space, access::decorated::yes>::pointer {
1182+
if constexpr (Space == access::address_space::global_space) {
1183+
return __SYCL_GenericCastToPtr_ToGlobal<T>(Ptr);
1184+
} else if constexpr (Space == access::address_space::local_space) {
1185+
return __SYCL_GenericCastToPtr_ToLocal<T>(Ptr);
1186+
} else if constexpr (Space == access::address_space::private_space) {
1187+
return __SYCL_GenericCastToPtr_ToPrivate<T>(Ptr);
1188+
}
1189+
}
1190+
1191+
template <access::address_space Space, typename T>
1192+
auto GenericCastToPtrExplicit(T *Ptr) ->
1193+
typename multi_ptr<T, Space, access::decorated::yes>::pointer {
1194+
if constexpr (Space == access::address_space::global_space) {
1195+
return __SYCL_GenericCastToPtrExplicit_ToGlobal<T>(Ptr);
1196+
} else if constexpr (Space == access::address_space::local_space) {
1197+
return __SYCL_GenericCastToPtrExplicit_ToLocal<T>(Ptr);
1198+
} else if constexpr (Space == access::address_space::private_space) {
1199+
return __SYCL_GenericCastToPtrExplicit_ToPrivate<T>(Ptr);
1200+
}
1201+
}
1202+
11791203
} // namespace spirv
11801204
} // namespace detail
11811205
} // namespace _V1
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//==----------- address_cast.hpp - sycl_ext_oneapi_address_cast ------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
#include <sycl/multi_ptr.hpp>
11+
12+
namespace sycl {
13+
inline namespace _V1 {
14+
namespace ext {
15+
namespace oneapi {
16+
namespace experimental {
17+
18+
template <access::address_space Space, access::decorated DecorateAddress,
19+
typename ElementType>
20+
multi_ptr<ElementType, Space, DecorateAddress>
21+
static_address_cast(ElementType *Ptr) {
22+
#ifdef __SYCL_DEVICE_ONLY__
23+
auto CastPtr = sycl::detail::spirv::GenericCastToPtr<Space>(Ptr);
24+
return multi_ptr<ElementType, Space, DecorateAddress>(CastPtr);
25+
#else
26+
return multi_ptr<ElementType, Space, DecorateAddress>(Ptr);
27+
#endif
28+
}
29+
30+
template <access::address_space Space, access::decorated DecorateAddress,
31+
typename ElementType>
32+
multi_ptr<ElementType, Space, DecorateAddress>
33+
dynamic_address_cast(ElementType *Ptr) {
34+
#ifdef __SYCL_DEVICE_ONLY__
35+
auto CastPtr = sycl::detail::spirv::GenericCastToPtrExplicit<Space>(Ptr);
36+
return multi_ptr<ElementType, Space, DecorateAddress>(CastPtr);
37+
#else
38+
return multi_ptr<ElementType, Space, DecorateAddress>(Ptr);
39+
#endif
40+
}
41+
42+
} // namespace experimental
43+
} // namespace oneapi
44+
} // namespace ext
45+
} // namespace _V1
46+
} // namespace sycl

sycl/include/sycl/sycl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
#include <sycl/ext/oneapi/bindless_images.hpp>
7575
#include <sycl/ext/oneapi/device_global/device_global.hpp>
7676
#include <sycl/ext/oneapi/device_global/properties.hpp>
77+
#include <sycl/ext/oneapi/experimental/address_cast.hpp>
7778
#include <sycl/ext/oneapi/experimental/annotated_arg/annotated_arg.hpp>
7879
#include <sycl/ext/oneapi/experimental/annotated_ptr/annotated_ptr.hpp>
7980
#include <sycl/ext/oneapi/experimental/annotated_usm/alloc_device.hpp>
@@ -100,7 +101,6 @@
100101
#include <sycl/ext/oneapi/sub_group.hpp>
101102
#include <sycl/ext/oneapi/sub_group_mask.hpp>
102103
#include <sycl/ext/oneapi/weak_object.hpp>
103-
104104
#if !defined(SYCL2020_CONFORMANT_APIS) && \
105105
!defined(__INTEL_PREVIEW_BREAKING_CHANGES)
106106
// We used to include those and some code might be reliant on that.
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
//==--------- dynamic_address_cast.cpp - dynamic address_cast test ---------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Issue with OpenCL CPU runtime implementation of OpGenericCastToPtrExplicit
10+
// UNSUPPORTED: cpu
11+
// RUN: %{build} -o %t.out
12+
// RUN: %{run} %t.out
13+
#include <sycl/sycl.hpp>
14+
15+
int main() {
16+
17+
sycl::queue Queue;
18+
19+
sycl::range<1> NItems{1};
20+
21+
sycl::buffer<int, 1> GlobalBuffer{NItems};
22+
sycl::buffer<bool, 1> ResultBuffer{NItems};
23+
24+
Queue
25+
.submit([&](sycl::handler &cgh) {
26+
auto GlobalAccessor =
27+
GlobalBuffer.get_access<sycl::access::mode::read_write>(cgh);
28+
auto LocalAccessor = sycl::local_accessor<int>(1, cgh);
29+
auto ResultAccessor =
30+
ResultBuffer.get_access<sycl::access::mode::write>(cgh);
31+
cgh.parallel_for<class Kernel>(
32+
sycl::nd_range<1>(NItems, 1), [=](sycl::nd_item<1> Item) {
33+
bool Success = true;
34+
size_t Index = Item.get_global_id(0);
35+
36+
int *RawGlobalPointer = &GlobalAccessor[Index];
37+
{
38+
auto GlobalPointer =
39+
sycl::ext::oneapi::experimental::dynamic_address_cast<
40+
sycl::access::address_space::global_space,
41+
sycl::access::decorated::no>(RawGlobalPointer);
42+
auto LocalPointer =
43+
sycl::ext::oneapi::experimental::dynamic_address_cast<
44+
sycl::access::address_space::local_space,
45+
sycl::access::decorated::no>(RawGlobalPointer);
46+
auto PrivatePointer =
47+
sycl::ext::oneapi::experimental::dynamic_address_cast<
48+
sycl::access::address_space::private_space,
49+
sycl::access::decorated::no>(RawGlobalPointer);
50+
Success &= reinterpret_cast<size_t>(RawGlobalPointer) ==
51+
reinterpret_cast<size_t>(GlobalPointer.get_raw());
52+
Success &= LocalPointer.get_raw() == nullptr;
53+
Success &= PrivatePointer.get_raw() == nullptr;
54+
}
55+
56+
int *RawLocalPointer = &LocalAccessor[0];
57+
{
58+
auto GlobalPointer =
59+
sycl::ext::oneapi::experimental::dynamic_address_cast<
60+
sycl::access::address_space::global_space,
61+
sycl::access::decorated::no>(RawLocalPointer);
62+
auto LocalPointer =
63+
sycl::ext::oneapi::experimental::dynamic_address_cast<
64+
sycl::access::address_space::local_space,
65+
sycl::access::decorated::no>(RawLocalPointer);
66+
auto PrivatePointer =
67+
sycl::ext::oneapi::experimental::dynamic_address_cast<
68+
sycl::access::address_space::private_space,
69+
sycl::access::decorated::no>(RawLocalPointer);
70+
Success &= GlobalPointer.get_raw() == nullptr;
71+
Success &= reinterpret_cast<size_t>(RawLocalPointer) ==
72+
reinterpret_cast<size_t>(LocalPointer.get_raw());
73+
Success &= PrivatePointer.get_raw() == nullptr;
74+
}
75+
76+
int PrivateVariable = 0;
77+
int *RawPrivatePointer = &PrivateVariable;
78+
{
79+
auto GlobalPointer =
80+
sycl::ext::oneapi::experimental::dynamic_address_cast<
81+
sycl::access::address_space::global_space,
82+
sycl::access::decorated::no>(RawPrivatePointer);
83+
auto LocalPointer =
84+
sycl::ext::oneapi::experimental::dynamic_address_cast<
85+
sycl::access::address_space::local_space,
86+
sycl::access::decorated::no>(RawPrivatePointer);
87+
auto PrivatePointer =
88+
sycl::ext::oneapi::experimental::dynamic_address_cast<
89+
sycl::access::address_space::private_space,
90+
sycl::access::decorated::no>(RawPrivatePointer);
91+
Success &= GlobalPointer.get_raw() == nullptr;
92+
Success &= LocalPointer.get_raw() == nullptr;
93+
Success &= reinterpret_cast<size_t>(RawPrivatePointer) ==
94+
reinterpret_cast<size_t>(PrivatePointer.get_raw());
95+
}
96+
97+
ResultAccessor[Index] = Success;
98+
});
99+
})
100+
.wait();
101+
102+
bool Success = true;
103+
{
104+
auto ResultAccessor = sycl::host_accessor(ResultBuffer);
105+
for (int i = 0; i < NItems.size(); ++i) {
106+
Success &= ResultAccessor[i];
107+
};
108+
}
109+
110+
return (Success) ? 0 : -1;
111+
}
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//==---------- static_address_cast.cpp - static address_cast test ----------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// RUN: %{build} -o %t.out
10+
// RUN: %{run} %t.out
11+
#include <sycl/sycl.hpp>
12+
13+
int main() {
14+
15+
sycl::queue Queue;
16+
17+
sycl::range<1> NItems{1};
18+
19+
sycl::buffer<int, 1> GlobalBuffer{NItems};
20+
sycl::buffer<bool, 1> ResultBuffer{NItems};
21+
22+
Queue
23+
.submit([&](sycl::handler &cgh) {
24+
auto GlobalAccessor =
25+
GlobalBuffer.get_access<sycl::access::mode::read_write>(cgh);
26+
auto LocalAccessor = sycl::local_accessor<int>(1, cgh);
27+
auto ResultAccessor =
28+
ResultBuffer.get_access<sycl::access::mode::write>(cgh);
29+
cgh.parallel_for<class Kernel>(
30+
sycl::nd_range<1>(NItems, 1), [=](sycl::nd_item<1> Item) {
31+
bool Success = true;
32+
size_t Index = Item.get_global_id(0);
33+
34+
int *RawGlobalPointer = &GlobalAccessor[Index];
35+
auto GlobalPointer =
36+
sycl::ext::oneapi::experimental::static_address_cast<
37+
sycl::access::address_space::global_space,
38+
sycl::access::decorated::no>(RawGlobalPointer);
39+
Success &= reinterpret_cast<size_t>(RawGlobalPointer) ==
40+
reinterpret_cast<size_t>(GlobalPointer.get_raw());
41+
42+
int *RawLocalPointer = &LocalAccessor[0];
43+
auto LocalPointer =
44+
sycl::ext::oneapi::experimental::static_address_cast<
45+
sycl::access::address_space::local_space,
46+
sycl::access::decorated::no>(RawLocalPointer);
47+
Success &= reinterpret_cast<size_t>(RawLocalPointer) ==
48+
reinterpret_cast<size_t>(LocalPointer.get_raw());
49+
50+
int PrivateVariable = 0;
51+
int *RawPrivatePointer = &PrivateVariable;
52+
auto PrivatePointer =
53+
sycl::ext::oneapi::experimental::static_address_cast<
54+
sycl::access::address_space::private_space,
55+
sycl::access::decorated::no>(RawPrivatePointer);
56+
Success &= reinterpret_cast<size_t>(RawPrivatePointer) ==
57+
reinterpret_cast<size_t>(PrivatePointer.get_raw());
58+
59+
ResultAccessor[Index] = Success;
60+
});
61+
})
62+
.wait();
63+
64+
bool Success = true;
65+
{
66+
auto ResultAccessor = sycl::host_accessor(ResultBuffer);
67+
for (int i = 0; i < NItems.size(); ++i) {
68+
Success &= ResultAccessor[i];
69+
};
70+
}
71+
72+
return (Success) ? 0 : -1;
73+
}

0 commit comments

Comments
 (0)