Commit a9ca820

[libc] Use clang's scoped atomics if available from the compiler (#74769)
Summary: A recent patch, #72280, gave `clang` the ability to easily emit scoped atomics. These are a special modifier on atomic operations that some backends support, intended to give more fine-grained control over which memory an atomic operation must be ordered against. The default is "system" scope, i.e. coherence with both the GPU and CPU memory on a heterogeneous system. "Device" scope instead implies that the memory is only ordered with respect to the current GPU. These builtins are direct replacements for the GCC atomic builtins in cases where the backend doesn't do anything with the scope information, so they should be a drop-in. This introduces some noise, but hopefully it isn't too contentious.
1 parent ecd4781 commit a9ca820
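In practice the change is mechanical: each `__atomic_*` call gains a scoped twin that takes one trailing scope argument. A minimal sketch of the substitution, assuming a clang new enough to define the `__scoped_atomic_*` builtins and the `__MEMORY_SCOPE_*` macros (the `load_relaxed` wrapper is a made-up example, not code from this patch):

```c++
// Illustrative only: the same load expressed with both builtin families.
// __has_builtin lets us fall back to the GCC-style builtin on older
// compilers, mirroring what the patch does with LIBC_HAS_BUILTIN.
int load_relaxed(int *ptr) {
#if __has_builtin(__scoped_atomic_load_n)
  // Same semantics as __atomic_load_n, plus a scope hint: DEVICE only asks
  // for coherence within the current GPU, not with the host CPU.
  return __scoped_atomic_load_n(ptr, __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE);
#else
  // GCC-style builtin: implicitly system scope.
  return __atomic_load_n(ptr, __ATOMIC_RELAXED);
#endif
}
```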

File tree

2 files changed: +75 −24 lines changed


libc/src/__support/CPP/atomic.h

Lines changed: 64 additions & 18 deletions
@@ -26,6 +26,18 @@ enum class MemoryOrder : int {
   SEQ_CST = __ATOMIC_SEQ_CST
 };
 
+// These are a clang extension, see the clang documentation for more information:
+// https://clang.llvm.org/docs/LanguageExtensions.html#scoped-atomic-builtins.
+enum class MemoryScope : int {
+#if defined(__MEMORY_SCOPE_SYSTEM) && defined(__MEMORY_SCOPE_DEVICE)
+  SYSTEM = __MEMORY_SCOPE_SYSTEM,
+  DEVICE = __MEMORY_SCOPE_DEVICE,
+#else
+  SYSTEM = 0,
+  DEVICE = 0,
+#endif
+};
+
 template <typename T> struct Atomic {
   // For now, we will restrict to only arithmetic types.
   static_assert(is_arithmetic_v<T>, "Only arithmetic types can be atomic.");
@@ -54,48 +66,82 @@ template <typename T> struct Atomic {
   Atomic(const Atomic &) = delete;
   Atomic &operator=(const Atomic &) = delete;
 
-  // Atomic load
+  // Atomic load.
   operator T() { return __atomic_load_n(&val, int(MemoryOrder::SEQ_CST)); }
 
-  T load(MemoryOrder mem_ord = MemoryOrder::SEQ_CST) {
-    return __atomic_load_n(&val, int(mem_ord));
+  T load(MemoryOrder mem_ord = MemoryOrder::SEQ_CST,
+         [[maybe_unused]] MemoryScope mem_scope = MemoryScope::DEVICE) {
+    if constexpr (LIBC_HAS_BUILTIN(__scoped_atomic_load_n))
+      return __scoped_atomic_load_n(&val, int(mem_ord), (int)(mem_scope));
+    else
+      return __atomic_load_n(&val, int(mem_ord));
   }
 
-  // Atomic store
+  // Atomic store.
   T operator=(T rhs) {
     __atomic_store_n(&val, rhs, int(MemoryOrder::SEQ_CST));
     return rhs;
   }
 
-  void store(T rhs, MemoryOrder mem_ord = MemoryOrder::SEQ_CST) {
-    __atomic_store_n(&val, rhs, int(mem_ord));
+  void store(T rhs, MemoryOrder mem_ord = MemoryOrder::SEQ_CST,
+             [[maybe_unused]] MemoryScope mem_scope = MemoryScope::DEVICE) {
+    if constexpr (LIBC_HAS_BUILTIN(__scoped_atomic_store_n))
+      __scoped_atomic_store_n(&val, rhs, int(mem_ord), (int)(mem_scope));
+    else
+      __atomic_store_n(&val, rhs, int(mem_ord));
   }
 
   // Atomic compare exchange
-  bool compare_exchange_strong(T &expected, T desired,
-                               MemoryOrder mem_ord = MemoryOrder::SEQ_CST) {
+  bool compare_exchange_strong(
+      T &expected, T desired, MemoryOrder mem_ord = MemoryOrder::SEQ_CST,
+      [[maybe_unused]] MemoryScope mem_scope = MemoryScope::DEVICE) {
     return __atomic_compare_exchange_n(&val, &expected, desired, false,
                                        int(mem_ord), int(mem_ord));
   }
 
-  T exchange(T desired, MemoryOrder mem_ord = MemoryOrder::SEQ_CST) {
-    return __atomic_exchange_n(&val, desired, int(mem_ord));
+  T exchange(T desired, MemoryOrder mem_ord = MemoryOrder::SEQ_CST,
+             [[maybe_unused]] MemoryScope mem_scope = MemoryScope::DEVICE) {
+    if constexpr (LIBC_HAS_BUILTIN(__scoped_atomic_exchange_n))
+      return __scoped_atomic_exchange_n(&val, desired, int(mem_ord),
+                                        (int)(mem_scope));
+    else
+      return __atomic_exchange_n(&val, desired, int(mem_ord));
   }
 
-  T fetch_add(T increment, MemoryOrder mem_ord = MemoryOrder::SEQ_CST) {
-    return __atomic_fetch_add(&val, increment, int(mem_ord));
+  T fetch_add(T increment, MemoryOrder mem_ord = MemoryOrder::SEQ_CST,
+              [[maybe_unused]] MemoryScope mem_scope = MemoryScope::DEVICE) {
+    if constexpr (LIBC_HAS_BUILTIN(__scoped_atomic_fetch_add))
+      return __scoped_atomic_fetch_add(&val, increment, int(mem_ord),
+                                       (int)(mem_scope));
+    else
+      return __atomic_fetch_add(&val, increment, int(mem_ord));
   }
 
-  T fetch_or(T mask, MemoryOrder mem_ord = MemoryOrder::SEQ_CST) {
-    return __atomic_fetch_or(&val, mask, int(mem_ord));
+  T fetch_or(T mask, MemoryOrder mem_ord = MemoryOrder::SEQ_CST,
+             [[maybe_unused]] MemoryScope mem_scope = MemoryScope::DEVICE) {
+    if constexpr (LIBC_HAS_BUILTIN(__scoped_atomic_fetch_or))
+      return __scoped_atomic_fetch_or(&val, mask, int(mem_ord),
+                                      (int)(mem_scope));
+    else
+      return __atomic_fetch_or(&val, mask, int(mem_ord));
   }
 
-  T fetch_and(T mask, MemoryOrder mem_ord = MemoryOrder::SEQ_CST) {
-    return __atomic_fetch_and(&val, mask, int(mem_ord));
+  T fetch_and(T mask, MemoryOrder mem_ord = MemoryOrder::SEQ_CST,
+              [[maybe_unused]] MemoryScope mem_scope = MemoryScope::DEVICE) {
+    if constexpr (LIBC_HAS_BUILTIN(__scoped_atomic_fetch_and))
+      return __scoped_atomic_fetch_and(&val, mask, int(mem_ord),
+                                       (int)(mem_scope));
+    else
+      return __atomic_fetch_and(&val, mask, int(mem_ord));
   }
 
-  T fetch_sub(T decrement, MemoryOrder mem_ord = MemoryOrder::SEQ_CST) {
-    return __atomic_fetch_sub(&val, decrement, int(mem_ord));
+  T fetch_sub(T decrement, MemoryOrder mem_ord = MemoryOrder::SEQ_CST,
+              [[maybe_unused]] MemoryScope mem_scope = MemoryScope::DEVICE) {
+    if constexpr (LIBC_HAS_BUILTIN(__scoped_atomic_fetch_sub))
+      return __scoped_atomic_fetch_sub(&val, decrement, int(mem_ord),
+                                       (int)(mem_scope));
+    else
+      return __atomic_fetch_sub(&val, decrement, int(mem_ord));
   }
 
   // Set the value without using an atomic operation. This is useful
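With the diff above applied, callers can pass a scope alongside the memory order, and the extra argument is simply ignored (hence the `[[maybe_unused]]`) when the compiler lacks the builtins. A usage sketch under stated assumptions: the namespace alias and both counters are mine, not part of this commit:

```c++
#include "src/__support/CPP/atomic.h" // cpp::Atomic, cpp::MemoryScope

#include <stdint.h>

namespace cpp = LIBC_NAMESPACE::cpp; // assumed namespace; may differ by tree

cpp::Atomic<uint32_t> device_counter; // touched only by this GPU's threads
cpp::Atomic<uint32_t> host_flag;      // also read by the host CPU

void update() {
  // Device scope: ordered only against other threads on the same GPU.
  device_counter.fetch_add(1, cpp::MemoryOrder::RELAXED,
                           cpp::MemoryScope::DEVICE);
  // System scope: must stay coherent with host memory as well.
  host_flag.store(1, cpp::MemoryOrder::RELEASE, cpp::MemoryScope::SYSTEM);
}
```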

libc/src/__support/RPC/rpc.h

Lines changed: 11 additions & 6 deletions
@@ -109,14 +109,16 @@ template <bool Invert, typename Packet> struct Process {
 
   /// Retrieve the inbox state from memory shared between processes.
   LIBC_INLINE uint32_t load_inbox(uint64_t lane_mask, uint32_t index) const {
-    return gpu::broadcast_value(lane_mask,
-                                inbox[index].load(cpp::MemoryOrder::RELAXED));
+    return gpu::broadcast_value(
+        lane_mask,
+        inbox[index].load(cpp::MemoryOrder::RELAXED, cpp::MemoryScope::SYSTEM));
   }
 
   /// Retrieve the outbox state from memory shared between processes.
   LIBC_INLINE uint32_t load_outbox(uint64_t lane_mask, uint32_t index) const {
     return gpu::broadcast_value(lane_mask,
-                                outbox[index].load(cpp::MemoryOrder::RELAXED));
+                                outbox[index].load(cpp::MemoryOrder::RELAXED,
+                                                   cpp::MemoryScope::SYSTEM));
   }
 
   /// Signal to the other process that this one is finished with the buffer.
@@ -126,7 +128,8 @@ template <bool Invert, typename Packet> struct Process {
   LIBC_INLINE uint32_t invert_outbox(uint32_t index, uint32_t current_outbox) {
     uint32_t inverted_outbox = !current_outbox;
     atomic_thread_fence(cpp::MemoryOrder::RELEASE);
-    outbox[index].store(inverted_outbox, cpp::MemoryOrder::RELAXED);
+    outbox[index].store(inverted_outbox, cpp::MemoryOrder::RELAXED,
+                        cpp::MemoryScope::SYSTEM);
     return inverted_outbox;
   }
 
@@ -241,7 +244,8 @@ template <bool Invert, typename Packet> struct Process {
     uint32_t slot = index / NUM_BITS_IN_WORD;
     uint32_t bit = index % NUM_BITS_IN_WORD;
     return bits[slot].fetch_or(static_cast<uint32_t>(cond) << bit,
-                               cpp::MemoryOrder::RELAXED) &
+                               cpp::MemoryOrder::RELAXED,
+                               cpp::MemoryScope::DEVICE) &
            (1u << bit);
   }
 
@@ -251,7 +255,8 @@ template <bool Invert, typename Packet> struct Process {
     uint32_t slot = index / NUM_BITS_IN_WORD;
     uint32_t bit = index % NUM_BITS_IN_WORD;
     return bits[slot].fetch_and(~0u ^ (static_cast<uint32_t>(cond) << bit),
-                                cpp::MemoryOrder::RELAXED) &
+                                cpp::MemoryOrder::RELAXED,
+                                cpp::MemoryScope::DEVICE) &
            (1u << bit);
   }
 };
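Note how this file splits the scopes: the inbox and outbox mailboxes live in memory shared between the CPU and the GPU, so their loads and stores keep `cpp::MemoryScope::SYSTEM`, while the `bits[]` words are only contended by threads on the same GPU and can drop to `cpp::MemoryScope::DEVICE`. That split is exactly the fine-grained control the summary describes: a backend that understands scopes may use cheaper synchronization for the device-scoped operations, and one that does not lowers both forms identically.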
