Skip to content

Commit dc51240

Browse files
AlexeySachkovbader
authored andcommitted
[SYCL][NFC] Refactor common code into a helper function
Signed-off-by: Alexey Sachkov <[email protected]>
1 parent 4c58035 commit dc51240

File tree

6 files changed

+60
-81
lines changed

6 files changed

+60
-81
lines changed

sycl/include/CL/sycl/atomic.hpp

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
//===----------------------------------------------------------------------===//
88

99
#pragma once
10-
#include <CL/sycl/access/access.hpp>
11-
#ifdef __SYCL_DEVICE_ONLY__
10+
1211
#include <CL/__spirv/spirv_ops.hpp>
13-
#else
14-
#include <CL/__spirv/spirv_types.hpp>
12+
#include <CL/sycl/access/access.hpp>
13+
#include <CL/sycl/detail/helpers.hpp>
14+
15+
#ifndef __SYCL_DEVICE_ONLY__
1516
#include <atomic>
1617
#endif
1718
#include <type_traits>
@@ -57,12 +58,6 @@ template <> struct GetSpirvMemoryScope<access::address_space::local_space> {
5758
static constexpr auto scope = __spv::Scope::Workgroup;
5859
};
5960

60-
// Translate the cl::sycl::memory_order to a SPIR-V builtin order
61-
static inline __spv::MemorySemanticsMask
62-
getSpirvMemorySemanticsMask(memory_order Order) {
63-
return __spv::MemorySemanticsMask::None;
64-
}
65-
6661
} // namespace detail
6762
} // namespace sycl
6863
} // namespace cl
@@ -196,17 +191,17 @@ class atomic {
196191

197192
void store(T Operand, memory_order Order = memory_order::relaxed) {
198193
__spirv_AtomicStore(
199-
Ptr, SpirvScope, detail::getSpirvMemorySemanticsMask(Order), Operand);
194+
Ptr, SpirvScope, detail::getSPIRVMemorySemanticsMask(Order), Operand);
200195
}
201196

202197
T load(memory_order Order = memory_order::relaxed) const {
203198
return __spirv_AtomicLoad(Ptr, SpirvScope,
204-
detail::getSpirvMemorySemanticsMask(Order));
199+
detail::getSPIRVMemorySemanticsMask(Order));
205200
}
206201

207202
T exchange(T Operand, memory_order Order = memory_order::relaxed) {
208203
return __spirv_AtomicExchange(
209-
Ptr, SpirvScope, detail::getSpirvMemorySemanticsMask(Order), Operand);
204+
Ptr, SpirvScope, detail::getSPIRVMemorySemanticsMask(Order), Operand);
210205
}
211206

212207
bool
@@ -216,8 +211,8 @@ class atomic {
216211
STATIC_ASSERT_NOT_FLOAT(T);
217212
#ifdef __SYCL_DEVICE_ONLY__
218213
T Value = __spirv_AtomicCompareExchange(
219-
Ptr, SpirvScope, detail::getSpirvMemorySemanticsMask(SuccessOrder),
220-
detail::getSpirvMemorySemanticsMask(FailOrder), Desired, Expected);
214+
Ptr, SpirvScope, detail::getSPIRVMemorySemanticsMask(SuccessOrder),
215+
detail::getSPIRVMemorySemanticsMask(FailOrder), Desired, Expected);
221216
return (Value == Expected);
222217
#else
223218
return Ptr->compare_exchange_strong(Expected, Desired,
@@ -229,43 +224,43 @@ class atomic {
229224
T fetch_add(T Operand, memory_order Order = memory_order::relaxed) {
230225
STATIC_ASSERT_NOT_FLOAT(T);
231226
return __spirv_AtomicIAdd(
232-
Ptr, SpirvScope, detail::getSpirvMemorySemanticsMask(Order), Operand);
227+
Ptr, SpirvScope, detail::getSPIRVMemorySemanticsMask(Order), Operand);
233228
}
234229

235230
T fetch_sub(T Operand, memory_order Order = memory_order::relaxed) {
236231
STATIC_ASSERT_NOT_FLOAT(T);
237232
return __spirv_AtomicISub(
238-
Ptr, SpirvScope, detail::getSpirvMemorySemanticsMask(Order), Operand);
233+
Ptr, SpirvScope, detail::getSPIRVMemorySemanticsMask(Order), Operand);
239234
}
240235

241236
T fetch_and(T Operand, memory_order Order = memory_order::relaxed) {
242237
STATIC_ASSERT_NOT_FLOAT(T);
243238
return __spirv_AtomicAnd(
244-
Ptr, SpirvScope, detail::getSpirvMemorySemanticsMask(Order), Operand);
239+
Ptr, SpirvScope, detail::getSPIRVMemorySemanticsMask(Order), Operand);
245240
}
246241

247242
T fetch_or(T Operand, memory_order Order = memory_order::relaxed) {
248243
STATIC_ASSERT_NOT_FLOAT(T);
249244
return __spirv_AtomicOr(
250-
Ptr, SpirvScope, detail::getSpirvMemorySemanticsMask(Order), Operand);
245+
Ptr, SpirvScope, detail::getSPIRVMemorySemanticsMask(Order), Operand);
251246
}
252247

253248
T fetch_xor(T Operand, memory_order Order = memory_order::relaxed) {
254249
STATIC_ASSERT_NOT_FLOAT(T);
255250
return __spirv_AtomicXor(
256-
Ptr, SpirvScope, detail::getSpirvMemorySemanticsMask(Order), Operand);
251+
Ptr, SpirvScope, detail::getSPIRVMemorySemanticsMask(Order), Operand);
257252
}
258253

259254
T fetch_min(T Operand, memory_order Order = memory_order::relaxed) {
260255
STATIC_ASSERT_NOT_FLOAT(T);
261256
return __spirv_AtomicMin(
262-
Ptr, SpirvScope, detail::getSpirvMemorySemanticsMask(Order), Operand);
257+
Ptr, SpirvScope, detail::getSPIRVMemorySemanticsMask(Order), Operand);
263258
}
264259

265260
T fetch_max(T Operand, memory_order Order = memory_order::relaxed) {
266261
STATIC_ASSERT_NOT_FLOAT(T);
267262
return __spirv_AtomicMax(
268-
Ptr, SpirvScope, detail::getSpirvMemorySemanticsMask(Order), Operand);
263+
Ptr, SpirvScope, detail::getSPIRVMemorySemanticsMask(Order), Operand);
269264
}
270265

271266
private:

sycl/include/CL/sycl/detail/helpers.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#pragma once
1010

11+
#include <CL/__spirv/spirv_types.hpp>
12+
#include <CL/sycl/access/access.hpp>
1113
#include <CL/sycl/detail/common.hpp>
1214

1315
#include <memory>
@@ -24,6 +26,7 @@ template <int dimensions> class group;
2426
template <int dimensions> class range;
2527
template <int dimensions> struct id;
2628
template <int dimensions> class nd_item;
29+
enum class memory_order;
2730
namespace detail {
2831
class context_impl;
2932
// The function returns list of events that can be passed to OpenCL API as
@@ -68,6 +71,35 @@ struct Builder {
6871
}
6972
};
7073

74+
inline __spv::MemorySemanticsMask getSPIRVMemorySemanticsMask(memory_order) {
75+
return __spv::MemorySemanticsMask::None;
76+
}
77+
78+
inline uint32_t
79+
getSPIRVMemorySemanticsMask(access::fence_space AccessSpace,
80+
__spv::MemorySemanticsMask LocalScopeMask =
81+
__spv::MemorySemanticsMask::WorkgroupMemory) {
82+
uint32_t Flags =
83+
static_cast<uint32_t>(__spv::MemorySemanticsMask::SequentiallyConsistent);
84+
switch (AccessSpace) {
85+
case access::fence_space::global_space:
86+
Flags |=
87+
static_cast<uint32_t>(__spv::MemorySemanticsMask::CrossWorkgroupMemory);
88+
break;
89+
case access::fence_space::local_space:
90+
Flags |= static_cast<uint32_t>(LocalScopeMask);
91+
break;
92+
case access::fence_space::global_and_local:
93+
default:
94+
Flags |= static_cast<uint32_t>(
95+
__spv::MemorySemanticsMask::CrossWorkgroupMemory) |
96+
static_cast<uint32_t>(LocalScopeMask);
97+
break;
98+
}
99+
100+
return Flags;
101+
}
102+
71103
} // namespace detail
72104
} // namespace sycl
73105
} // namespace cl

sycl/include/CL/sycl/group.hpp

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <CL/__spirv/spirv_ops.hpp>
12+
#include <CL/sycl/detail/helpers.hpp>
1213
#include <CL/sycl/device_event.hpp>
1314
#include <CL/sycl/id.hpp>
1415
#include <CL/sycl/pointers.hpp>
@@ -81,25 +82,7 @@ template <int dimensions = 1> class group {
8182
accessMode == access::mode::read_write,
8283
access::fence_space>::type accessSpace =
8384
access::fence_space::global_and_local) const {
84-
uint32_t flags = static_cast<uint32_t>(
85-
__spv::MemorySemanticsMask::SequentiallyConsistent);
86-
switch (accessSpace) {
87-
case access::fence_space::global_space:
88-
flags |= static_cast<uint32_t>(
89-
__spv::MemorySemanticsMask::CrossWorkgroupMemory);
90-
break;
91-
case access::fence_space::local_space:
92-
flags |=
93-
static_cast<uint32_t>(__spv::MemorySemanticsMask::WorkgroupMemory);
94-
break;
95-
case access::fence_space::global_and_local:
96-
default:
97-
flags |=
98-
static_cast<uint32_t>(
99-
__spv::MemorySemanticsMask::CrossWorkgroupMemory) |
100-
static_cast<uint32_t>(__spv::MemorySemanticsMask::WorkgroupMemory);
101-
break;
102-
}
85+
uint32_t flags = detail::getSPIRVMemorySemanticsMask(accessSpace);
10386
// TODO: currently, there is no good way in SPIRV to set the memory
10487
// barrier only for load operations or only for store operations.
10588
// The full read-and-write barrier is used and the template parameter

sycl/include/CL/sycl/intel/sub_group.hpp

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <CL/__spirv/spirv_vars.hpp>
1212
#include <CL/sycl/access/access.hpp>
13+
#include <CL/sycl/detail/helpers.hpp>
1314
#include <CL/sycl/id.hpp>
1415
#include <CL/sycl/range.hpp>
1516
#include <CL/sycl/types.hpp>
@@ -334,25 +335,8 @@ struct sub_group {
334335
/* --- synchronization functions --- */
335336
void barrier(access::fence_space accessSpace =
336337
access::fence_space::global_and_local) const {
337-
uint32_t flags = static_cast<uint32_t>(
338-
__spv::MemorySemanticsMask::SequentiallyConsistent);
339-
switch (accessSpace) {
340-
case access::fence_space::global_space:
341-
flags |= static_cast<uint32_t>(
342-
__spv::MemorySemanticsMask::CrossWorkgroupMemory);
343-
break;
344-
case access::fence_space::local_space:
345-
flags |=
346-
static_cast<uint32_t>(__spv::MemorySemanticsMask::SubgroupMemory);
347-
break;
348-
case access::fence_space::global_and_local:
349-
default:
350-
flags |=
351-
static_cast<uint32_t>(
352-
__spv::MemorySemanticsMask::CrossWorkgroupMemory) |
353-
static_cast<uint32_t>(__spv::MemorySemanticsMask::SubgroupMemory);
354-
break;
355-
}
338+
uint32_t flags = detail::getSPIRVMemorySemanticsMask(
339+
accessSpace, __spv::MemorySemanticsMask::SubgroupMemory);
356340
__spirv_ControlBarrier(__spv::Scope::Subgroup, __spv::Scope::Workgroup,
357341
flags);
358342
}

sycl/include/CL/sycl/nd_item.hpp

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@
88

99
#pragma once
1010

11+
#include <CL/__spirv/spirv_ops.hpp>
1112
#include <CL/sycl/access/access.hpp>
13+
#include <CL/sycl/detail/helpers.hpp>
1214
#include <CL/sycl/group.hpp>
1315
#include <CL/sycl/id.hpp>
1416
#include <CL/sycl/intel/sub_group.hpp>
1517
#include <CL/sycl/item.hpp>
1618
#include <CL/sycl/nd_range.hpp>
1719
#include <CL/sycl/range.hpp>
18-
#include <CL/__spirv/spirv_ops.hpp>
20+
1921
#include <stdexcept>
2022
#include <type_traits>
2123

@@ -81,25 +83,7 @@ template <int dimensions = 1> struct nd_item {
8183

8284
void barrier(access::fence_space accessSpace =
8385
access::fence_space::global_and_local) const {
84-
uint32_t flags = static_cast<uint32_t>(
85-
__spv::MemorySemanticsMask::SequentiallyConsistent);
86-
switch (accessSpace) {
87-
case access::fence_space::global_space:
88-
flags |= static_cast<uint32_t>(
89-
__spv::MemorySemanticsMask::CrossWorkgroupMemory);
90-
break;
91-
case access::fence_space::local_space:
92-
flags |=
93-
static_cast<uint32_t>(__spv::MemorySemanticsMask::WorkgroupMemory);
94-
break;
95-
case access::fence_space::global_and_local:
96-
default:
97-
flags |=
98-
static_cast<uint32_t>(
99-
__spv::MemorySemanticsMask::CrossWorkgroupMemory) |
100-
static_cast<uint32_t>(__spv::MemorySemanticsMask::WorkgroupMemory);
101-
break;
102-
}
86+
uint32_t flags = detail::getSPIRVMemorySemanticsMask(accessSpace);
10387
__spirv_ControlBarrier(__spv::Scope::Workgroup, __spv::Scope::Workgroup,
10488
flags);
10589
}

sycl/source/detail/helpers.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include <CL/sycl/detail/context_impl.hpp>
109
#include <CL/sycl/detail/helpers.hpp>
10+
11+
#include <CL/sycl/detail/context_impl.hpp>
1112
#include <CL/sycl/event.hpp>
1213

1314
#include <memory>

0 commit comments

Comments
 (0)