Skip to content

Commit a3c3425

Browse files
authored
[SYCL] Add prototype of atomic_ref<T*> (#2177)
Enables partial specialization of atomic_ref for pointer types. Implementation assumes that both host and device pointers can be stored in a uintptr_t, but uses compare_exchange to implement pointer arithmetic rather than make assumptions about how pointers will be represented on different devices. Signed-off-by: John Pennycook <[email protected]>
1 parent 8ac87a3 commit a3c3425

File tree

10 files changed

+204
-97
lines changed

10 files changed

+204
-97
lines changed

sycl/doc/extensions/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ DPC++ extensions status:
1313
| [SYCL_INTEL_deduction_guides](deduction_guides/SYCL_INTEL_deduction_guides.asciidoc) | Supported | |
1414
| [SYCL_INTEL_device_specific_kernel_queries](DeviceSpecificKernelQueries/SYCL_INTEL_device_specific_kernel_queries.asciidoc) | Proposal | |
1515
| [SYCL_INTEL_enqueue_barrier](EnqueueBarrier/enqueue_barrier.asciidoc) | Supported(OpenCL, Level Zero) | |
16-
| [SYCL_INTEL_extended_atomics](ExtendedAtomics/SYCL_INTEL_extended_atomics.asciidoc) | Partially supported(OpenCL: CPU, GPU) | Not supported: pointer types |
16+
| [SYCL_INTEL_extended_atomics](ExtendedAtomics/SYCL_INTEL_extended_atomics.asciidoc) | Supported(OpenCL: CPU, GPU) | |
1717
| [SYCL_INTEL_group_algorithms](GroupAlgorithms/SYCL_INTEL_group_algorithms.asciidoc) | Supported(OpenCL) | |
1818
| [SYCL_INTEL_group_mask](./GroupMask/SYCL_INTEL_group_mask.asciidoc) | Proposal | |
1919
| [FPGA selector](IntelFPGA/FPGASelector.md) | Supported | |

sycl/include/CL/sycl/intel/atomic_ref.hpp

Lines changed: 130 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,6 @@ class atomic_ref_base {
135135
static_assert(!(std::is_same<T, short>::value ||
136136
std::is_same<T, unsigned short>::value),
137137
"intel::atomic_ref does not support short type");
138-
static_assert(!std::is_pointer<T>::value,
139-
"intel::atomic_ref does not yet support pointer types");
140138
static_assert(detail::IsValidAtomicAddressSpace<AddressSpace>::value,
141139
"Invalid atomic address_space. Valid address spaces are: "
142140
"global_space, local_space, global_device_space");
@@ -508,12 +506,138 @@ class atomic_ref_impl<
508506
};
509507

510508
// Partial specialization for pointer types
509+
// Arithmetic is emulated because target's representation of T* is unknown
510+
// TODO: Find a way to use intptr_t or uintptr_t atomics instead
511511
template <typename T, memory_order DefaultOrder, memory_scope DefaultScope,
512512
access::address_space AddressSpace>
513-
class atomic_ref_impl<T *, DefaultOrder, DefaultScope, AddressSpace,
514-
typename detail::enable_if_t<std::is_pointer<T>::value>>
515-
: public atomic_ref_base<T *, DefaultOrder, DefaultScope, AddressSpace> {
516-
// TODO: Implement partial specialization for pointer types
513+
class atomic_ref_impl<T *, DefaultOrder, DefaultScope, AddressSpace>
514+
: public atomic_ref_base<uintptr_t, DefaultOrder, DefaultScope,
515+
AddressSpace> {
516+
517+
private:
518+
using base_type =
519+
atomic_ref_base<uintptr_t, DefaultOrder, DefaultScope, AddressSpace>;
520+
521+
public:
522+
using value_type = T *;
523+
using difference_type = ptrdiff_t;
524+
static constexpr size_t required_alignment = sizeof(T *);
525+
static constexpr bool is_always_lock_free =
526+
detail::IsValidAtomicType<T>::value;
527+
static constexpr memory_order default_read_order =
528+
detail::memory_order_traits<DefaultOrder>::read_order;
529+
static constexpr memory_order default_write_order =
530+
detail::memory_order_traits<DefaultOrder>::write_order;
531+
static constexpr memory_order default_read_modify_write_order = DefaultOrder;
532+
static constexpr memory_scope default_scope = DefaultScope;
533+
534+
using base_type::is_lock_free;
535+
536+
atomic_ref_impl(T *&ref) : base_type(reinterpret_cast<uintptr_t &>(ref)) {}
537+
538+
void store(T *operand, memory_order order = default_write_order,
539+
memory_scope scope = default_scope) const noexcept {
540+
base_type::store(reinterpret_cast<uintptr_t>(operand), order, scope);
541+
}
542+
543+
T *operator=(T *desired) const noexcept {
544+
store(desired);
545+
return desired;
546+
}
547+
548+
T *load(memory_order order = default_read_order,
549+
memory_scope scope = default_scope) const noexcept {
550+
return reinterpret_cast<T *>(base_type::load(order, scope));
551+
}
552+
553+
operator T *() const noexcept { return load(); }
554+
555+
T *exchange(T *operand, memory_order order = default_read_modify_write_order,
556+
memory_scope scope = default_scope) const noexcept {
557+
return reinterpret_cast<T *>(base_type::exchange(
558+
reinterpret_cast<uintptr_t>(operand), order, scope));
559+
}
560+
561+
T *fetch_add(difference_type operand,
562+
memory_order order = default_read_modify_write_order,
563+
memory_scope scope = default_scope) const noexcept {
564+
// TODO: Find a way to avoid compare_exchange here
565+
auto load_order = detail::getLoadOrder(order);
566+
T *expected = load(load_order, scope);
567+
T *desired;
568+
do {
569+
desired = expected + operand;
570+
} while (!compare_exchange_weak(expected, desired, order, scope));
571+
return expected;
572+
}
573+
574+
T *operator+=(difference_type operand) const noexcept {
575+
return fetch_add(operand) + operand;
576+
}
577+
578+
T *operator++(int) const noexcept { return fetch_add(difference_type(1)); }
579+
580+
T *operator++() const noexcept {
581+
return fetch_add(difference_type(1)) + difference_type(1);
582+
}
583+
584+
T *fetch_sub(difference_type operand,
585+
memory_order order = default_read_modify_write_order,
586+
memory_scope scope = default_scope) const noexcept {
587+
// TODO: Find a way to avoid compare_exchange here
588+
auto load_order = detail::getLoadOrder(order);
589+
T *expected = load(load_order, scope);
590+
T *desired;
591+
do {
592+
desired = expected - operand;
593+
} while (!compare_exchange_weak(expected, desired, order, scope));
594+
return expected;
595+
}
596+
597+
T *operator-=(difference_type operand) const noexcept {
598+
return fetch_sub(operand) - operand;
599+
}
600+
601+
T *operator--(int) const noexcept { return fetch_sub(difference_type(1)); }
602+
603+
T *operator--() const noexcept {
604+
return fetch_sub(difference_type(1)) - difference_type(1);
605+
}
606+
607+
bool
608+
compare_exchange_strong(T *&expected, T *desired, memory_order success,
609+
memory_order failure,
610+
memory_scope scope = default_scope) const noexcept {
611+
return base_type::compare_exchange_strong(
612+
reinterpret_cast<uintptr_t &>(expected),
613+
reinterpret_cast<uintptr_t>(desired), success, failure, scope);
614+
}
615+
616+
bool
617+
compare_exchange_strong(T *&expected, T *desired,
618+
memory_order order = default_read_modify_write_order,
619+
memory_scope scope = default_scope) const noexcept {
620+
return compare_exchange_strong(expected, desired, order, order, scope);
621+
}
622+
623+
bool
624+
compare_exchange_weak(T *&expected, T *desired, memory_order success,
625+
memory_order failure,
626+
memory_scope scope = default_scope) const noexcept {
627+
return base_type::compare_exchange_weak(
628+
reinterpret_cast<uintptr_t &>(expected),
629+
reinterpret_cast<uintptr_t>(desired), success, failure, scope);
630+
}
631+
632+
bool
633+
compare_exchange_weak(T *&expected, T *desired,
634+
memory_order order = default_read_modify_write_order,
635+
memory_scope scope = default_scope) const noexcept {
636+
return compare_exchange_weak(expected, desired, order, order, scope);
637+
}
638+
639+
private:
640+
using base_type::ptr;
517641
};
518642

519643
} // namespace detail

sycl/test/atomic_ref/add.cpp

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
using namespace sycl;
1313
using namespace sycl::intel;
1414

15-
template <typename T>
15+
template <typename T, typename Difference = T>
1616
void add_fetch_test(queue q, size_t N) {
1717
T sum = 0;
1818
std::vector<T> output(N);
19-
std::fill(output.begin(), output.end(), 0);
19+
std::fill(output.begin(), output.end(), T(0));
2020
{
2121
buffer<T> sum_buf(&sum, 1);
2222
buffer<T> output_buf(output.data(), output.size());
@@ -27,29 +27,29 @@ void add_fetch_test(queue q, size_t N) {
2727
cgh.parallel_for(range<1>(N), [=](item<1> it) {
2828
int gid = it.get_id(0);
2929
auto atm = atomic_ref<T, intel::memory_order::relaxed, intel::memory_scope::device, access::address_space::global_space>(sum[0]);
30-
out[gid] = atm.fetch_add(T(1));
30+
out[gid] = atm.fetch_add(Difference(1));
3131
});
3232
});
3333
}
3434

3535
// All work-items increment by 1, so final value should be equal to N
36-
assert(sum == N);
36+
assert(sum == T(N));
3737

3838
// Fetch returns original value: will be in [0, N-1]
3939
auto min_e = std::min_element(output.begin(), output.end());
4040
auto max_e = std::max_element(output.begin(), output.end());
41-
assert(*min_e == 0 && *max_e == N - 1);
41+
assert(*min_e == T(0) && *max_e == T(N - 1));
4242

4343
// Intermediate values should be unique
4444
std::sort(output.begin(), output.end());
4545
assert(std::unique(output.begin(), output.end()) == output.end());
4646
}
4747

48-
template <typename T>
48+
template <typename T, typename Difference = T>
4949
void add_plus_equal_test(queue q, size_t N) {
5050
T sum = 0;
5151
std::vector<T> output(N);
52-
std::fill(output.begin(), output.end(), 0);
52+
std::fill(output.begin(), output.end(), T(0));
5353
{
5454
buffer<T> sum_buf(&sum, 1);
5555
buffer<T> output_buf(output.data(), output.size());
@@ -60,29 +60,29 @@ void add_plus_equal_test(queue q, size_t N) {
6060
cgh.parallel_for(range<1>(N), [=](item<1> it) {
6161
int gid = it.get_id(0);
6262
auto atm = atomic_ref<T, intel::memory_order::relaxed, intel::memory_scope::device, access::address_space::global_space>(sum[0]);
63-
out[gid] = atm += T(1);
63+
out[gid] = atm += Difference(1);
6464
});
6565
});
6666
}
6767

6868
// All work-items increment by 1, so final value should be equal to N
69-
assert(sum == N);
69+
assert(sum == T(N));
7070

7171
// += returns updated value: will be in [1, N]
7272
auto min_e = std::min_element(output.begin(), output.end());
7373
auto max_e = std::max_element(output.begin(), output.end());
74-
assert(*min_e == 1 && *max_e == N);
74+
assert(*min_e == T(1) && *max_e == T(N));
7575

7676
// Intermediate values should be unique
7777
std::sort(output.begin(), output.end());
7878
assert(std::unique(output.begin(), output.end()) == output.end());
7979
}
8080

81-
template <typename T>
81+
template <typename T, typename Difference = T>
8282
void add_pre_inc_test(queue q, size_t N) {
8383
T sum = 0;
8484
std::vector<T> output(N);
85-
std::fill(output.begin(), output.end(), 0);
85+
std::fill(output.begin(), output.end(), T(0));
8686
{
8787
buffer<T> sum_buf(&sum, 1);
8888
buffer<T> output_buf(output.data(), output.size());
@@ -99,23 +99,23 @@ void add_pre_inc_test(queue q, size_t N) {
9999
}
100100

101101
// All work-items increment by 1, so final value should be equal to N
102-
assert(sum == N);
102+
assert(sum == T(N));
103103

104104
// Pre-increment returns updated value: will be in [1, N]
105105
auto min_e = std::min_element(output.begin(), output.end());
106106
auto max_e = std::max_element(output.begin(), output.end());
107-
assert(*min_e == 1 && *max_e == N);
107+
assert(*min_e == T(1) && *max_e == T(N));
108108

109109
// Intermediate values should be unique
110110
std::sort(output.begin(), output.end());
111111
assert(std::unique(output.begin(), output.end()) == output.end());
112112
}
113113

114-
template <typename T>
114+
template <typename T, typename Difference = T>
115115
void add_post_inc_test(queue q, size_t N) {
116116
T sum = 0;
117117
std::vector<T> output(N);
118-
std::fill(output.begin(), output.end(), 0);
118+
std::fill(output.begin(), output.end(), T(0));
119119
{
120120
buffer<T> sum_buf(&sum, 1);
121121
buffer<T> output_buf(output.data(), output.size());
@@ -132,24 +132,24 @@ void add_post_inc_test(queue q, size_t N) {
132132
}
133133

134134
// All work-items increment by 1, so final value should be equal to N
135-
assert(sum == N);
135+
assert(sum == T(N));
136136

137137
// Post-increment returns original value: will be in [0, N-1]
138138
auto min_e = std::min_element(output.begin(), output.end());
139139
auto max_e = std::max_element(output.begin(), output.end());
140-
assert(*min_e == 0 && *max_e == N - 1);
140+
assert(*min_e == T(0) && *max_e == T(N - 1));
141141

142142
// Intermediate values should be unique
143143
std::sort(output.begin(), output.end());
144144
assert(std::unique(output.begin(), output.end()) == output.end());
145145
}
146146

147-
template <typename T>
147+
template <typename T, typename Difference = T>
148148
void add_test(queue q, size_t N) {
149-
add_fetch_test<T>(q, N);
150-
add_plus_equal_test<T>(q, N);
151-
add_pre_inc_test<T>(q, N);
152-
add_post_inc_test<T>(q, N);
149+
add_fetch_test<T, Difference>(q, N);
150+
add_plus_equal_test<T, Difference>(q, N);
151+
add_pre_inc_test<T, Difference>(q, N);
152+
add_post_inc_test<T, Difference>(q, N);
153153
}
154154

155155
// Floating-point types do not support pre- or post-increment
@@ -173,8 +173,6 @@ int main() {
173173
}
174174

175175
constexpr int N = 32;
176-
177-
// TODO: Enable missing tests when supported
178176
add_test<int>(q, N);
179177
add_test<unsigned int>(q, N);
180178
add_test<long>(q, N);
@@ -183,7 +181,7 @@ int main() {
183181
add_test<unsigned long long>(q, N);
184182
add_test<float>(q, N);
185183
add_test<double>(q, N);
186-
//add_test<char*>(q, N);
184+
add_test<char *, ptrdiff_t>(q, N);
187185

188186
std::cout << "Test passed." << std::endl;
189187
}

sycl/test/atomic_ref/compare_exchange.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,27 @@ class compare_exchange_kernel;
1616

1717
template <typename T>
1818
void compare_exchange_test(queue q, size_t N) {
19-
const T initial = std::numeric_limits<T>::max();
19+
const T initial = T(N);
2020
T compare_exchange = initial;
2121
std::vector<T> output(N);
22-
std::fill(output.begin(), output.end(), 0);
22+
std::fill(output.begin(), output.end(), T(0));
2323
{
2424
buffer<T> compare_exchange_buf(&compare_exchange, 1);
2525
buffer<T> output_buf(output.data(), output.size());
2626

2727
q.submit([&](handler &cgh) {
2828
auto exc = compare_exchange_buf.template get_access<access::mode::read_write>(cgh);
2929
auto out = output_buf.template get_access<access::mode::discard_write>(cgh);
30-
cgh.parallel_for<compare_exchange_kernel<T>>(range<1>(N), [=](item<1> it) {
31-
int gid = it.get_id(0);
30+
cgh.parallel_for<compare_exchange_kernel<T>>(range<1>(N), [=](item<1>
31+
it) {
32+
size_t gid = it.get_id(0);
3233
auto atm = atomic_ref<T, intel::memory_order::relaxed, intel::memory_scope::device, access::address_space::global_space>(exc[0]);
33-
T result = initial;
34+
T result = T(N); // Avoid copying pointer
3435
bool success = atm.compare_exchange_strong(result, (T)gid);
3536
if (success) {
3637
out[gid] = result;
3738
} else {
38-
out[gid] = gid;
39+
out[gid] = T(gid);
3940
}
4041
});
4142
});
@@ -45,7 +46,7 @@ void compare_exchange_test(queue q, size_t N) {
4546
assert(std::count(output.begin(), output.end(), initial) == 1);
4647

4748
// All other values should be the index itself or the sentinel value
48-
for (int i = 0; i < N; ++i) {
49+
for (size_t i = 0; i < N; ++i) {
4950
assert(output[i] == T(i) || output[i] == initial);
5051
}
5152
}
@@ -59,8 +60,6 @@ int main() {
5960
}
6061

6162
constexpr int N = 32;
62-
63-
// TODO: Enable missing tests when supported
6463
compare_exchange_test<int>(q, N);
6564
compare_exchange_test<unsigned int>(q, N);
6665
compare_exchange_test<long>(q, N);
@@ -69,7 +68,7 @@ int main() {
6968
compare_exchange_test<unsigned long long>(q, N);
7069
compare_exchange_test<float>(q, N);
7170
compare_exchange_test<double>(q, N);
72-
//compare_exchange_test<char*>(q, N);
71+
compare_exchange_test<char *>(q, N);
7372

7473
std::cout << "Test passed." << std::endl;
7574
}

0 commit comments

Comments
 (0)