Skip to content

Commit a4c53e4

Browse files
[SYCL] Align sycl_ext_oneapi_address_cast impl with the spec (#15402)
Reflects spec changes from #12689
1 parent aa2748d commit a4c53e4

File tree

5 files changed

+113
-84
lines changed

5 files changed

+113
-84
lines changed

sycl/doc/extensions/proposed/sycl_ext_oneapi_address_cast.asciidoc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,10 @@ implementation supports.
9494
namespace sycl::ext::oneapi::experimental {
9595
9696
// Shorthands for address space names
97-
constexpr inline address_space global_space = sycl::access::address_space::global_space;
98-
constexpr inline address_space local_space = sycl::access::address_space::local_space;
99-
constexpr inline address_space private_space = sycl::access::address_space::private_space;
100-
constexpr inline address_space generic_space = sycl::access::address_space::generic_space;
97+
constexpr inline access::address_space global_space = access::address_space::global_space;
98+
constexpr inline access::address_space local_space = access::address_space::local_space;
99+
constexpr inline access::address_space private_space = access::address_space::private_space;
100+
constexpr inline access::address_space generic_space = access::address_space::generic_space;
101101
102102
template <access::address_space Space,
103103
typename ElementType>

sycl/include/sycl/ext/oneapi/experimental/address_cast.hpp

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,49 +16,74 @@ inline namespace _V1 {
1616
namespace ext {
1717
namespace oneapi {
1818
namespace experimental {
19+
// Shorthands for address space names
20+
constexpr inline access::address_space global_space = access::address_space::global_space;
21+
constexpr inline access::address_space local_space = access::address_space::local_space;
22+
constexpr inline access::address_space private_space = access::address_space::private_space;
23+
constexpr inline access::address_space generic_space = access::address_space::generic_space;
1924

20-
template <access::address_space Space, access::decorated DecorateAddress,
21-
typename ElementType>
22-
multi_ptr<ElementType, Space, DecorateAddress>
25+
template <access::address_space Space, typename ElementType>
26+
multi_ptr<ElementType, Space, access::decorated::no>
2327
static_address_cast(ElementType *Ptr) {
28+
using ret_ty = multi_ptr<ElementType, Space, access::decorated::no>;
2429
#ifdef __SYCL_DEVICE_ONLY__
2530
// TODO: Remove this restriction.
2631
static_assert(std::is_same_v<ElementType, remove_decoration_t<ElementType>>,
2732
"The extension expect undecorated raw pointers only!");
28-
if constexpr (Space == access::address_space::generic_space) {
33+
if constexpr (Space == generic_space) {
2934
// Undecorated raw pointer is in generic AS already, no extra casts needed.
3035
// Note for future, for `OpPtrCastToGeneric`, `Pointer` must point to one of
3136
// `Storage Classes` that doesn't include `Generic`, so this will have to
3237
// remain a special case even if the restriction above is lifted.
33-
return multi_ptr<ElementType, Space, DecorateAddress>(Ptr);
38+
return ret_ty(Ptr);
3439
} else {
3540
auto CastPtr = sycl::detail::spirv::GenericCastToPtr<Space>(Ptr);
36-
return multi_ptr<ElementType, Space, DecorateAddress>(CastPtr);
41+
return ret_ty(CastPtr);
3742
}
3843
#else
39-
return multi_ptr<ElementType, Space, DecorateAddress>(Ptr);
44+
return ret_ty(Ptr);
4045
#endif
4146
}
4247

4348
template <access::address_space Space, access::decorated DecorateAddress,
4449
typename ElementType>
45-
multi_ptr<ElementType, Space, DecorateAddress>
50+
multi_ptr<ElementType, Space, DecorateAddress> static_address_cast(
51+
multi_ptr<ElementType, generic_space, DecorateAddress> Ptr) {
52+
if constexpr (Space == generic_space)
53+
return Ptr;
54+
else
55+
return {static_address_cast<Space>(Ptr.get_raw())};
56+
}
57+
58+
template <access::address_space Space, typename ElementType>
59+
multi_ptr<ElementType, Space, access::decorated::no>
4660
dynamic_address_cast(ElementType *Ptr) {
61+
using ret_ty = multi_ptr<ElementType, Space, access::decorated::no>;
4762
#ifdef __SYCL_DEVICE_ONLY__
4863
// TODO: Remove this restriction.
4964
static_assert(std::is_same_v<ElementType, remove_decoration_t<ElementType>>,
5065
"The extension expect undecorated raw pointers only!");
51-
if constexpr (Space == access::address_space::generic_space) {
52-
return multi_ptr<ElementType, Space, DecorateAddress>(Ptr);
66+
if constexpr (Space == generic_space) {
67+
return ret_ty(Ptr);
5368
} else {
5469
auto CastPtr = sycl::detail::spirv::GenericCastToPtrExplicit<Space>(Ptr);
55-
return multi_ptr<ElementType, Space, DecorateAddress>(CastPtr);
70+
return ret_ty(CastPtr);
5671
}
5772
#else
58-
return multi_ptr<ElementType, Space, DecorateAddress>(Ptr);
73+
return ret_ty(Ptr);
5974
#endif
6075
}
6176

77+
template <access::address_space Space, access::decorated DecorateAddress,
78+
typename ElementType>
79+
multi_ptr<ElementType, Space, DecorateAddress> dynamic_address_cast(
80+
multi_ptr<ElementType, generic_space, DecorateAddress> Ptr) {
81+
if constexpr (Space == generic_space)
82+
return Ptr;
83+
else
84+
return {dynamic_address_cast<Space>(Ptr.get_raw())};
85+
}
86+
6287
} // namespace experimental
6388
} // namespace oneapi
6489
} // namespace ext

sycl/test-e2e/AddressCast/dynamic_address_cast.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,16 @@ int main() {
4242
{
4343
auto GlobalPointer =
4444
sycl::ext::oneapi::experimental::dynamic_address_cast<
45-
sycl::access::address_space::global_space,
46-
sycl::access::decorated::no>(RawGlobalPointer);
45+
sycl::access::address_space::global_space>(
46+
RawGlobalPointer);
4747
auto LocalPointer =
4848
sycl::ext::oneapi::experimental::dynamic_address_cast<
49-
sycl::access::address_space::local_space,
50-
sycl::access::decorated::no>(RawGlobalPointer);
49+
sycl::access::address_space::local_space>(
50+
RawGlobalPointer);
5151
auto PrivatePointer =
5252
sycl::ext::oneapi::experimental::dynamic_address_cast<
53-
sycl::access::address_space::private_space,
54-
sycl::access::decorated::no>(RawGlobalPointer);
53+
sycl::access::address_space::private_space>(
54+
RawGlobalPointer);
5555
Success &= reinterpret_cast<size_t>(RawGlobalPointer) ==
5656
reinterpret_cast<size_t>(GlobalPointer.get_raw());
5757
Success &= LocalPointer.get_raw() == nullptr;
@@ -62,16 +62,16 @@ int main() {
6262
{
6363
auto GlobalPointer =
6464
sycl::ext::oneapi::experimental::dynamic_address_cast<
65-
sycl::access::address_space::global_space,
66-
sycl::access::decorated::no>(RawLocalPointer);
65+
sycl::access::address_space::global_space>(
66+
RawLocalPointer);
6767
auto LocalPointer =
6868
sycl::ext::oneapi::experimental::dynamic_address_cast<
69-
sycl::access::address_space::local_space,
70-
sycl::access::decorated::no>(RawLocalPointer);
69+
sycl::access::address_space::local_space>(
70+
RawLocalPointer);
7171
auto PrivatePointer =
7272
sycl::ext::oneapi::experimental::dynamic_address_cast<
73-
sycl::access::address_space::private_space,
74-
sycl::access::decorated::no>(RawLocalPointer);
73+
sycl::access::address_space::private_space>(
74+
RawLocalPointer);
7575
Success &= GlobalPointer.get_raw() == nullptr;
7676
Success &= reinterpret_cast<size_t>(RawLocalPointer) ==
7777
reinterpret_cast<size_t>(LocalPointer.get_raw());
@@ -83,16 +83,16 @@ int main() {
8383
{
8484
auto GlobalPointer =
8585
sycl::ext::oneapi::experimental::dynamic_address_cast<
86-
sycl::access::address_space::global_space,
87-
sycl::access::decorated::no>(RawPrivatePointer);
86+
sycl::access::address_space::global_space>(
87+
RawPrivatePointer);
8888
auto LocalPointer =
8989
sycl::ext::oneapi::experimental::dynamic_address_cast<
90-
sycl::access::address_space::local_space,
91-
sycl::access::decorated::no>(RawPrivatePointer);
90+
sycl::access::address_space::local_space>(
91+
RawPrivatePointer);
9292
auto PrivatePointer =
9393
sycl::ext::oneapi::experimental::dynamic_address_cast<
94-
sycl::access::address_space::private_space,
95-
sycl::access::decorated::no>(RawPrivatePointer);
94+
sycl::access::address_space::private_space>(
95+
RawPrivatePointer);
9696
Success &= GlobalPointer.get_raw() == nullptr;
9797
Success &= LocalPointer.get_raw() == nullptr;
9898
Success &= reinterpret_cast<size_t>(RawPrivatePointer) ==

sycl/test-e2e/AddressCast/static_address_cast.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,25 @@ int main() {
3939
int *RawGlobalPointer = &GlobalAccessor[Index];
4040
auto GlobalPointer =
4141
sycl::ext::oneapi::experimental::static_address_cast<
42-
sycl::access::address_space::global_space,
43-
sycl::access::decorated::no>(RawGlobalPointer);
42+
sycl::access::address_space::global_space>(
43+
RawGlobalPointer);
4444
Success &= reinterpret_cast<size_t>(RawGlobalPointer) ==
4545
reinterpret_cast<size_t>(GlobalPointer.get_raw());
4646

4747
int *RawLocalPointer = &LocalAccessor[0];
4848
auto LocalPointer =
4949
sycl::ext::oneapi::experimental::static_address_cast<
50-
sycl::access::address_space::local_space,
51-
sycl::access::decorated::no>(RawLocalPointer);
50+
sycl::access::address_space::local_space>(
51+
RawLocalPointer);
5252
Success &= reinterpret_cast<size_t>(RawLocalPointer) ==
5353
reinterpret_cast<size_t>(LocalPointer.get_raw());
5454

5555
int PrivateVariable = 0;
5656
int *RawPrivatePointer = &PrivateVariable;
5757
auto PrivatePointer =
5858
sycl::ext::oneapi::experimental::static_address_cast<
59-
sycl::access::address_space::private_space,
60-
sycl::access::decorated::no>(RawPrivatePointer);
59+
sycl::access::address_space::private_space>(
60+
RawPrivatePointer);
6161
Success &= reinterpret_cast<size_t>(RawPrivatePointer) ==
6262
reinterpret_cast<size_t>(PrivatePointer.get_raw());
6363

0 commit comments

Comments
 (0)