Skip to content

Commit 37a9a2a

Browse files
author
Artem Gindinson
authored
[SYCL] Specialize atomic fetch_add for floating point types (#2765)
The new EXT/SPV_EXT_shader_atomic_float_add SPIR-V extension allows us to further specialize atomic::fetch_add() for floating point types. In device mode, we'll now be creating an external call to a built-in-like __spirv_AtomicFAddEXT(). This is similar to what is done for other atomic binary instructions, e.g. the integer specialization of fetch_add() being mapped onto __spirv_AtomicIAdd(). Furthermore, atomic::fetch_sub() is also re-implemented to use __spirv_AtomicFAddEXT(), the added operand being a negation of the original one. The new implementation can be exposed if a dedicated macro is defined: SYCL_USE_NATIVE_FP_ATOMICS. Otherwise, a fallback is used, where the atomic operation is done via spinlock emulation. At the moment of committing this, only Intel GPUs support the "native" implementation, which relies on a SPIR-V extension. Tests for the feature have been finalized in intel/llvm-test-suite#104. Signed-off-by: Artem Gindinson [email protected]
1 parent a4d4cfb commit 37a9a2a

File tree

5 files changed

+82
-27
lines changed

5 files changed

+82
-27
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ extern SYCL_EXTERNAL TempRetT __spirv_ImageSampleExplicitLod(SampledType,
7979
extern SYCL_EXTERNAL Type __spirv_AtomicISub( \
8080
AS Type *P, __spv::Scope::Flag S, __spv::MemorySemanticsMask::Flag O, \
8181
Type V);
82+
#define __SPIRV_ATOMIC_FADD(AS, Type) \
83+
extern SYCL_EXTERNAL Type __spirv_AtomicFAddEXT( \
84+
AS Type *P, __spv::Scope::Flag S, __spv::MemorySemanticsMask::Flag O, \
85+
Type V);
8286
#define __SPIRV_ATOMIC_SMIN(AS, Type) \
8387
extern SYCL_EXTERNAL Type __spirv_AtomicSMin( \
8488
AS Type *P, __spv::Scope::Flag S, __spv::MemorySemanticsMask::Flag O, \
@@ -109,6 +113,7 @@ extern SYCL_EXTERNAL TempRetT __spirv_ImageSampleExplicitLod(SampledType,
109113
Type V);
110114

111115
#define __SPIRV_ATOMIC_FLOAT(AS, Type) \
116+
__SPIRV_ATOMIC_FADD(AS, Type) \
112117
__SPIRV_ATOMIC_LOAD(AS, Type) \
113118
__SPIRV_ATOMIC_STORE(AS, Type) \
114119
__SPIRV_ATOMIC_EXCHANGE(AS, Type)

sycl/include/CL/sycl/ONEAPI/atomic_ref.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,11 @@ class atomic_ref_impl<
453453

454454
T fetch_add(T operand, memory_order order = default_read_modify_write_order,
455455
memory_scope scope = default_scope) const noexcept {
456+
// TODO: Remove the "native atomics" macro check once implemented for all
457+
// backends
458+
#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_USE_NATIVE_FP_ATOMICS)
459+
return detail::spirv::AtomicFAdd(ptr, scope, order, operand);
460+
#else
456461
auto load_order = detail::getLoadOrder(order);
457462
T expected;
458463
T desired;
@@ -462,6 +467,7 @@ class atomic_ref_impl<
462467
desired = expected + operand;
463468
} while (!compare_exchange_weak(expected, desired, order, scope));
464469
return expected;
470+
#endif
465471
}
466472

467473
T operator+=(T operand) const noexcept {
@@ -470,13 +476,19 @@ class atomic_ref_impl<
470476

471477
T fetch_sub(T operand, memory_order order = default_read_modify_write_order,
472478
memory_scope scope = default_scope) const noexcept {
479+
// TODO: Remove the "native atomics" macro check once implemented for all
480+
// backends
481+
#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_USE_NATIVE_FP_ATOMICS)
482+
return detail::spirv::AtomicFAdd(ptr, scope, order, -operand);
483+
#else
473484
auto load_order = detail::getLoadOrder(order);
474485
T expected = load(load_order, scope);
475486
T desired;
476487
do {
477488
desired = expected - operand;
478489
} while (!compare_exchange_weak(expected, desired, order, scope));
479490
return expected;
491+
#endif
480492
}
481493

482494
T operator-=(T operand) const noexcept {

sycl/include/CL/sycl/detail/spirv.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,16 @@ AtomicISub(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
385385
return __spirv_AtomicISub(Ptr, SPIRVScope, SPIRVOrder, Value);
386386
}
387387

388+
template <typename T, access::address_space AddressSpace>
389+
inline typename detail::enable_if_t<std::is_floating_point<T>::value, T>
390+
AtomicFAdd(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,
391+
ONEAPI::memory_order Order, T Value) {
392+
auto *Ptr = MPtr.get();
393+
auto SPIRVOrder = getMemorySemanticsMask(Order);
394+
auto SPIRVScope = getScope(Scope);
395+
return __spirv_AtomicFAddEXT(Ptr, SPIRVScope, SPIRVOrder, Value);
396+
}
397+
388398
template <typename T, access::address_space AddressSpace>
389399
inline typename detail::enable_if_t<std::is_integral<T>::value, T>
390400
AtomicAnd(multi_ptr<T, AddressSpace> MPtr, ONEAPI::memory_scope Scope,

sycl/test/atomic_ref/add.cpp

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
// TODO: Once NVPTX accepts the __spirv_AtomicF*() IR, remove the XFAIL mark
2+
// XFAIL: cuda
3+
4+
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -DSYCL_USE_NATIVE_FP_ATOMICS \
5+
// RUN: -fsycl-device-only -S %s -o - | FileCheck %s --check-prefix=CHECK-LLVM
16
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -fsycl-device-only -S %s -o - \
2-
// RUN: | FileCheck %s --check-prefix=CHECK-LLVM
7+
// RUN: | FileCheck %s --check-prefix=CHECK-LLVM-EMU
38
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -fsycl-targets=%sycl_triple %s -o %t.out
49
// RUN: %RUN_ON_HOST %t.out
510

@@ -167,22 +172,24 @@ void add_test(queue q, size_t N) {
167172
// Floating-point types do not support pre- or post-increment
168173
template <> void add_test<float>(queue q, size_t N) {
169174
add_fetch_test<float>(q, N);
170-
// CHECK-LLVM: declare dso_local spir_func i32
171-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicLoad
172-
// CHECK-LLVM-SAME: (i32 addrspace(1)*, i32, i32)
173-
// CHECK-LLVM: declare dso_local spir_func i32
174-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicCompareExchange
175-
// CHECK-LLVM-SAME: (i32 addrspace(1)*, i32, i32, i32, i32, i32)
175+
// CHECK-LLVM: declare dso_local spir_func float
176+
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicFAddEXT
177+
// CHECK-LLVM-SAME: (float addrspace(1)*, i32, i32, float)
178+
// CHECK-LLVM-EMU: declare {{.*}} i32 @{{.*}}__spirv_AtomicLoad
179+
// CHECK-LLVM-EMU-SAME: (i32 addrspace(1)*, i32, i32)
180+
// CHECK-LLVM-EMU: declare {{.*}} i32 @{{.*}}__spirv_AtomicCompareExchange
181+
// CHECK-LLVM-EMU-SAME: (i32 addrspace(1)*, i32, i32, i32, i32, i32)
176182
add_plus_equal_test<float>(q, N);
177183
}
178184
template <> void add_test<double>(queue q, size_t N) {
179185
add_fetch_test<double>(q, N);
180-
// CHECK-LLVM: declare dso_local spir_func i64
181-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicLoad
182-
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32)
183-
// CHECK-LLVM: declare dso_local spir_func i64
184-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicCompareExchange
185-
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32, i32, i64, i64)
186+
// CHECK-LLVM: declare dso_local spir_func double
187+
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicFAddEXT
188+
// CHECK-LLVM-SAME: (double addrspace(1)*, i32, i32, double)
189+
// CHECK-LLVM-EMU: declare {{.*}} i64 @{{.*}}__spirv_AtomicLoad
190+
// CHECK-LLVM-EMU-SAME: (i64 addrspace(1)*, i32, i32)
191+
// CHECK-LLVM-EMU: declare {{.*}} i64 @{{.*}}__spirv_AtomicCompareExchange
192+
// CHECK-LLVM-EMU-SAME: (i64 addrspace(1)*, i32, i32, i32, i64, i64)
186193
add_plus_equal_test<double>(q, N);
187194
}
188195

@@ -219,9 +226,15 @@ int main() {
219226
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicIAdd
220227
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32, i64)
221228
add_test<unsigned long long>(q, N);
222-
// The remaining functions have been instantiated earlier
229+
// Floating point-typed functions have been instantiated earlier
223230
add_test<float>(q, N);
224231
add_test<double>(q, N);
232+
// CHECK-LLVM: declare dso_local spir_func i64
233+
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicLoad
234+
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32)
235+
// CHECK-LLVM: declare dso_local spir_func i64
236+
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicCompareExchange
237+
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32, i32, i64, i64)
225238
add_test<char *, ptrdiff_t>(q, N);
226239

227240
std::cout << "Test passed." << std::endl;

sycl/test/atomic_ref/sub.cpp

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
// TODO: Once NVPTX accepts the __spirv_AtomicF*() IR, remove the XFAIL mark
2+
// XFAIL: cuda
3+
4+
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -DSYCL_USE_NATIVE_FP_ATOMICS \
5+
// RUN: -fsycl-device-only -S %s -o - | FileCheck %s --check-prefix=CHECK-LLVM
6+
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -fsycl-device-only -S %s -o - \
7+
// RUN: | FileCheck %s --check-prefix=CHECK-LLVM-EMU
18
// RUN: %clangxx -fsycl -fsycl-unnamed-lambda -fsycl-targets=%sycl_triple %s -o %t.out
29
// RUN: %RUN_ON_HOST %t.out
310

@@ -165,22 +172,24 @@ void sub_test(queue q, size_t N) {
165172
// Floating-point types do not support pre- or post-decrement
166173
template <> void sub_test<float>(queue q, size_t N) {
167174
sub_fetch_test<float>(q, N);
168-
// CHECK-LLVM: declare dso_local spir_func i32
169-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicLoad
170-
// CHECK-LLVM-SAME: (i32 addrspace(1)*, i32, i32)
171-
// CHECK-LLVM: declare dso_local spir_func i32
172-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicCompareExchange
173-
// CHECK-LLVM-SAME: (i32 addrspace(1)*, i32, i32, i32, i32, i32)
175+
// CHECK-LLVM: declare dso_local spir_func float
176+
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicFAddEXT
177+
// CHECK-LLVM-SAME: (float addrspace(1)*, i32, i32, float)
178+
// CHECK-LLVM-EMU: declare {{.*}} i32 @{{.*}}__spirv_AtomicLoad
179+
// CHECK-LLVM-EMU-SAME: (i32 addrspace(1)*, i32, i32)
180+
// CHECK-LLVM-EMU: declare {{.*}} i32 @{{.*}}__spirv_AtomicCompareExchange
181+
// CHECK-LLVM-EMU-SAME: (i32 addrspace(1)*, i32, i32, i32, i32, i32)
174182
sub_plus_equal_test<float>(q, N);
175183
}
176184
template <> void sub_test<double>(queue q, size_t N) {
177185
sub_fetch_test<double>(q, N);
178-
// CHECK-LLVM: declare dso_local spir_func i64
179-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicLoad
180-
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32)
181-
// CHECK-LLVM: declare dso_local spir_func i64
182-
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicCompareExchange
183-
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32, i32, i64, i64)
186+
// CHECK-LLVM: declare dso_local spir_func double
187+
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicFAddEXT
188+
// CHECK-LLVM-SAME: (double addrspace(1)*, i32, i32, double)
189+
// CHECK-LLVM-EMU: declare {{.*}} i64 @{{.*}}__spirv_AtomicLoad
190+
// CHECK-LLVM-EMU-SAME: (i64 addrspace(1)*, i32, i32)
191+
// CHECK-LLVM-EMU: declare {{.*}} i64 @{{.*}}__spirv_AtomicCompareExchange
192+
// CHECK-LLVM-EMU-SAME: (i64 addrspace(1)*, i32, i32, i32, i64, i64)
184193
sub_plus_equal_test<double>(q, N);
185194
}
186195

@@ -217,9 +226,15 @@ int main() {
217226
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicISub
218227
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32, i64)
219228
sub_test<unsigned long long>(q, N);
220-
// The remaining functions have been instantiated earlier
229+
// Floating point-typed functions have been instantiated earlier
221230
sub_test<float>(q, N);
222231
sub_test<double>(q, N);
232+
// CHECK-LLVM: declare dso_local spir_func i64
233+
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicLoad
234+
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32)
235+
// CHECK-LLVM: declare dso_local spir_func i64
236+
// CHECK-LLVM-SAME: @_Z{{[0-9]+}}__spirv_AtomicCompareExchange
237+
// CHECK-LLVM-SAME: (i64 addrspace(1)*, i32, i32, i32, i64, i64)
223238
sub_test<char *, ptrdiff_t>(q, N);
224239

225240
std::cout << "Test passed." << std::endl;

0 commit comments

Comments
 (0)