Skip to content

Commit 4f91bbb

Browse files
authored
[SYCL] Fix multi_ptr relational operators customised for nullptr. (#13201)
`multi_ptr` relational operators taking a `std::nullptr_t` are written in a way that assume it is the lowest possible value (example https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/multi_ptr.hpp#L1575). However the C++ specs states there is no ordering requirement (https://eel.is/c++draft/expr.rel#4.3). In practice, this is causing issues in the CUDA and AMDGPU backend. For instance, in CUDA the nullptr in the local address space is a non `0` value and `0` in this address space is the root of the allocated local memory.
1 parent 0f6148a commit 4f91bbb

File tree

2 files changed

+105
-12
lines changed

2 files changed

+105
-12
lines changed

sycl/include/sycl/multi_ptr.hpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,56 +1560,64 @@ template <typename ElementType, access::address_space Space,
15601560
access::decorated DecorateAddress>
15611561
bool operator>(const multi_ptr<ElementType, Space, DecorateAddress> &lhs,
15621562
std::nullptr_t) {
1563-
return lhs.get() != nullptr;
1563+
return lhs.get() >
1564+
multi_ptr<ElementType, Space, DecorateAddress>(nullptr).get();
15641565
}
15651566

15661567
template <typename ElementType, access::address_space Space,
15671568
access::decorated DecorateAddress>
15681569
bool operator>(std::nullptr_t,
1569-
const multi_ptr<ElementType, Space, DecorateAddress> &) {
1570-
return false;
1570+
const multi_ptr<ElementType, Space, DecorateAddress> &rhs) {
1571+
return multi_ptr<ElementType, Space, DecorateAddress>(nullptr).get() >
1572+
rhs.get();
15711573
}
15721574

15731575
template <typename ElementType, access::address_space Space,
15741576
access::decorated DecorateAddress>
1575-
bool operator<(const multi_ptr<ElementType, Space, DecorateAddress> &,
1577+
bool operator<(const multi_ptr<ElementType, Space, DecorateAddress> &lhs,
15761578
std::nullptr_t) {
1577-
return false;
1579+
return lhs.get() <
1580+
multi_ptr<ElementType, Space, DecorateAddress>(nullptr).get();
15781581
}
15791582

15801583
template <typename ElementType, access::address_space Space,
15811584
access::decorated DecorateAddress>
15821585
bool operator<(std::nullptr_t,
15831586
const multi_ptr<ElementType, Space, DecorateAddress> &rhs) {
1584-
return rhs.get() != nullptr;
1587+
return multi_ptr<ElementType, Space, DecorateAddress>(nullptr).get() <
1588+
rhs.get();
15851589
}
15861590

15871591
template <typename ElementType, access::address_space Space,
15881592
access::decorated DecorateAddress>
1589-
bool operator>=(const multi_ptr<ElementType, Space, DecorateAddress> &,
1593+
bool operator>=(const multi_ptr<ElementType, Space, DecorateAddress> &lhs,
15901594
std::nullptr_t) {
1591-
return true;
1595+
return lhs.get() >=
1596+
multi_ptr<ElementType, Space, DecorateAddress>(nullptr).get();
15921597
}
15931598

15941599
template <typename ElementType, access::address_space Space,
15951600
access::decorated DecorateAddress>
15961601
bool operator>=(std::nullptr_t,
15971602
const multi_ptr<ElementType, Space, DecorateAddress> &rhs) {
1598-
return rhs.get() == nullptr;
1603+
return multi_ptr<ElementType, Space, DecorateAddress>(nullptr).get() >=
1604+
rhs.get();
15991605
}
16001606

16011607
template <typename ElementType, access::address_space Space,
16021608
access::decorated DecorateAddress>
16031609
bool operator<=(const multi_ptr<ElementType, Space, DecorateAddress> &lhs,
16041610
std::nullptr_t) {
1605-
return lhs.get() == nullptr;
1611+
return lhs.get() <=
1612+
multi_ptr<ElementType, Space, DecorateAddress>(nullptr).get();
16061613
}
16071614

16081615
template <typename ElementType, access::address_space Space,
16091616
access::decorated DecorateAddress>
16101617
bool operator<=(std::nullptr_t,
1611-
const multi_ptr<ElementType, Space, DecorateAddress> &) {
1612-
return true;
1618+
const multi_ptr<ElementType, Space, DecorateAddress> &rhs) {
1619+
return multi_ptr<ElementType, Space, DecorateAddress>(nullptr).get() <=
1620+
rhs.get();
16131621
}
16141622

16151623
} // namespace _V1
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#include <sycl/detail/core.hpp>
5+
#include <sycl/multi_ptr.hpp>
6+
7+
template <typename T, typename AccessorType,
8+
sycl::access::address_space address_space,
9+
sycl::access::decorated decorated>
10+
void check(sycl::multi_ptr<T, address_space, decorated> &mp,
11+
AccessorType &dev_acc) {
12+
using multi_ptr_t = sycl::multi_ptr<T, address_space, decorated>;
13+
multi_ptr_t null_mp;
14+
dev_acc[0] = nullptr == null_mp;
15+
dev_acc[0] += nullptr != mp;
16+
dev_acc[0] += std::less<multi_ptr_t>()(nullptr, mp) == nullptr < mp;
17+
dev_acc[0] += std::less<multi_ptr_t>()(mp, nullptr) == mp < nullptr;
18+
dev_acc[0] += std::less_equal<multi_ptr_t>()(nullptr, mp) == nullptr <= mp;
19+
dev_acc[0] += std::less_equal<multi_ptr_t>()(mp, nullptr) == mp <= nullptr;
20+
dev_acc[0] += std::greater<multi_ptr_t>()(nullptr, mp) == nullptr > mp;
21+
dev_acc[0] += std::greater<multi_ptr_t>()(mp, nullptr) == mp > nullptr;
22+
dev_acc[0] += std::greater_equal<multi_ptr_t>()(nullptr, mp) == nullptr >= mp;
23+
dev_acc[0] += std::greater_equal<multi_ptr_t>()(mp, nullptr) == mp >= nullptr;
24+
}
25+
26+
template <typename T, sycl::access::address_space address_space,
27+
sycl::access::decorated decorated>
28+
void nullptrRelationalOperatorTest() {
29+
using multi_ptr_t = sycl::multi_ptr<int, address_space, decorated>;
30+
try {
31+
sycl::queue queue;
32+
sycl::buffer<int, 1> buf(1);
33+
queue
34+
.submit([&](sycl::handler &cgh) {
35+
auto dev_acc = buf.get_access<sycl::access::mode::write>(cgh);
36+
if constexpr (address_space ==
37+
sycl::access::address_space::local_space) {
38+
sycl::local_accessor<int, 1> locAcc(1, cgh);
39+
cgh.parallel_for(sycl::nd_range<1>{1, 1}, [=](sycl::id<1>) {
40+
locAcc[0] = 1;
41+
multi_ptr_t mp(locAcc);
42+
check(mp, dev_acc);
43+
});
44+
} else if constexpr (address_space ==
45+
sycl::access::address_space::private_space) {
46+
cgh.single_task([=] {
47+
T priv_arr[1];
48+
sycl::multi_ptr<T, address_space, decorated> mp =
49+
sycl::address_space_cast<address_space, decorated>(priv_arr);
50+
check(mp, dev_acc);
51+
});
52+
} else {
53+
cgh.single_task([=] {
54+
multi_ptr_t mp(dev_acc);
55+
check(mp, dev_acc);
56+
});
57+
}
58+
})
59+
.wait_and_throw();
60+
assert(sycl::host_accessor{buf}[0] == 10);
61+
} catch (sycl::exception e) {
62+
std::cout << "SYCL exception caught: " << e.what();
63+
return;
64+
}
65+
}
66+
67+
int main() {
68+
nullptrRelationalOperatorTest<int, sycl::access::address_space::local_space,
69+
sycl::access::decorated::yes>();
70+
nullptrRelationalOperatorTest<int, sycl::access::address_space::local_space,
71+
sycl::access::decorated::no>();
72+
nullptrRelationalOperatorTest<int, sycl::access::address_space::global_space,
73+
sycl::access::decorated::yes>();
74+
nullptrRelationalOperatorTest<int, sycl::access::address_space::global_space,
75+
sycl::access::decorated::no>();
76+
nullptrRelationalOperatorTest<int, sycl::access::address_space::generic_space,
77+
sycl::access::decorated::yes>();
78+
nullptrRelationalOperatorTest<int, sycl::access::address_space::generic_space,
79+
sycl::access::decorated::no>();
80+
nullptrRelationalOperatorTest<int, sycl::access::address_space::private_space,
81+
sycl::access::decorated::yes>();
82+
nullptrRelationalOperatorTest<int, sycl::access::address_space::private_space,
83+
sycl::access::decorated::no>();
84+
return 0;
85+
}

0 commit comments

Comments
 (0)