Skip to content

Commit 94b36ac

Browse files
authored
[SYCL] Add device_ptr and host_ptr (#1864)
Currently a device backend can't trace from where a pointer allocated by USM comes: it can be either allocated on host or on device (it's just a pointer in OpenCL global address space). On FPGAs at least we can generate more efficient hardware code if the user tells us where the pointer can point. With this change users can create multi_ptr with specialized address space global_host or global_device that will proved to the compiler additional information to process load-store optimizations. Accessor pointers shall be also moved to global_device address spaces - otherwise backend would assume, that a pointer in global address space can access both host and device memory. Previously there were added global_device in global_host address spaces for OpenCL/SYCL in clang. With this patch device_space and host_space were added in the SYCL headers the are mapped into the new address spaces and aliases to multi_ptr instantiated with the space: device_ptr and host_ptr. Signed-off-by: Dmitry Sidorov <[email protected]>
1 parent eaf3396 commit 94b36ac

File tree

7 files changed

+137
-29
lines changed

7 files changed

+137
-29
lines changed

sycl/include/CL/sycl/access/access.hpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ enum class address_space : int {
4545
private_space = 0,
4646
global_space,
4747
constant_space,
48-
local_space
48+
local_space,
49+
global_device_space,
50+
global_host_space
4951
};
5052

5153
} // namespace access
@@ -103,11 +105,15 @@ constexpr bool modeWritesNewData(access::mode m) {
103105

104106
#ifdef __SYCL_DEVICE_ONLY__
105107
#define __OPENCL_GLOBAL_AS__ __attribute__((opencl_global))
108+
#define __OPENCL_GLOBAL_DEVICE_AS__ __attribute__((opencl_global_device))
109+
#define __OPENCL_GLOBAL_HOST_AS__ __attribute__((opencl_global_host))
106110
#define __OPENCL_LOCAL_AS__ __attribute__((opencl_local))
107111
#define __OPENCL_CONSTANT_AS__ __attribute__((opencl_constant))
108112
#define __OPENCL_PRIVATE_AS__ __attribute__((opencl_private))
109113
#else
110114
#define __OPENCL_GLOBAL_AS__
115+
#define __OPENCL_GLOBAL_DEVICE_AS__
116+
#define __OPENCL_GLOBAL_HOST_AS__
111117
#define __OPENCL_LOCAL_AS__
112118
#define __OPENCL_CONSTANT_AS__
113119
#define __OPENCL_PRIVATE_AS__
@@ -141,6 +147,16 @@ struct PtrValueType<ElementType, access::address_space::global_space> {
141147
using type = __OPENCL_GLOBAL_AS__ ElementType;
142148
};
143149

150+
template <typename ElementType>
151+
struct PtrValueType<ElementType, access::address_space::global_device_space> {
152+
using type = __OPENCL_GLOBAL_DEVICE_AS__ ElementType;
153+
};
154+
155+
template <typename ElementType>
156+
struct PtrValueType<ElementType, access::address_space::global_host_space> {
157+
using type = __OPENCL_GLOBAL_HOST_AS__ ElementType;
158+
};
159+
144160
template <typename ElementType>
145161
struct PtrValueType<ElementType, access::address_space::constant_space> {
146162
// Current implementation of address spaces handling leads to possibility
@@ -171,6 +187,14 @@ struct remove_AS<__OPENCL_GLOBAL_AS__ T> {
171187
typedef T type;
172188
};
173189

190+
template <class T> struct remove_AS<__OPENCL_GLOBAL_DEVICE_AS__ T> {
191+
typedef T type;
192+
};
193+
194+
template <class T> struct remove_AS<__OPENCL_GLOBAL_HOST_AS__ T> {
195+
typedef T type;
196+
};
197+
174198
template <class T>
175199
struct remove_AS<__OPENCL_PRIVATE_AS__ T> {
176200
typedef T type;
@@ -188,6 +212,8 @@ struct remove_AS<__OPENCL_CONSTANT_AS__ T> {
188212
#endif
189213

190214
#undef __OPENCL_GLOBAL_AS__
215+
#undef __OPENCL_GLOBAL_DEVICE_AS__
216+
#undef __OPENCL_GLOBAL_HOST_AS__
191217
#undef __OPENCL_LOCAL_AS__
192218
#undef __OPENCL_CONSTANT_AS__
193219
#undef __OPENCL_PRIVATE_AS__

sycl/include/CL/sycl/atomic.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@ template <typename T> struct IsValidAtomicType {
4646
};
4747

4848
template <cl::sycl::access::address_space AS> struct IsValidAtomicAddressSpace {
49-
static constexpr bool value = (AS == access::address_space::global_space ||
50-
AS == access::address_space::local_space);
49+
static constexpr bool value =
50+
(AS == access::address_space::global_space ||
51+
AS == access::address_space::local_space ||
52+
AS == access::address_space::global_device_space);
5153
};
5254

5355
// Type trait to translate a cl::sycl::access::address_space to
@@ -56,6 +58,10 @@ template <access::address_space AS> struct GetSpirvMemoryScope {};
5658
template <> struct GetSpirvMemoryScope<access::address_space::global_space> {
5759
static constexpr auto scope = __spv::Scope::Device;
5860
};
61+
template <>
62+
struct GetSpirvMemoryScope<access::address_space::global_device_space> {
63+
static constexpr auto scope = __spv::Scope::Device;
64+
};
5965
template <> struct GetSpirvMemoryScope<access::address_space::local_space> {
6066
static constexpr auto scope = __spv::Scope::Workgroup;
6167
};
@@ -168,12 +174,12 @@ template <typename T, access::address_space addressSpace =
168174
access::address_space::global_space>
169175
class atomic {
170176
static_assert(detail::IsValidAtomicType<T>::value,
171-
"Invalid SYCL atomic type. Valid types are: int, "
172-
"unsigned int, long, unsigned long, long long, unsigned "
177+
"Invalid SYCL atomic type. Valid types are: int, "
178+
"unsigned int, long, unsigned long, long long, unsigned "
173179
"long long, float");
174180
static_assert(detail::IsValidAtomicAddressSpace<addressSpace>::value,
175-
"Invalid SYCL atomic address_space. Valid address spaces are: "
176-
"global_space, local_space");
181+
"Invalid SYCL atomic address_space. Valid address spaces are: "
182+
"global_space, local_space, global_device_space");
177183
static constexpr auto SpirvScope =
178184
detail::GetSpirvMemoryScope<addressSpace>::scope;
179185

sycl/include/CL/sycl/detail/generic_type_lists.hpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -361,21 +361,25 @@ using nan_list = type_list<gtl::unsigned_short_list, gtl::unsigned_int_list,
361361
} // namespace gtl
362362
namespace gvl {
363363
// address spaces
364-
using all_address_space_list =
365-
address_space_list<access::address_space::local_space,
366-
access::address_space::global_space,
367-
access::address_space::private_space,
368-
access::address_space::constant_space>;
364+
using all_address_space_list = address_space_list<
365+
access::address_space::local_space, access::address_space::global_space,
366+
access::address_space::private_space, access::address_space::constant_space,
367+
access::address_space::global_device_space,
368+
access::address_space::global_host_space>;
369369

370370
using nonconst_address_space_list =
371371
address_space_list<access::address_space::local_space,
372372
access::address_space::global_space,
373-
access::address_space::private_space>;
373+
access::address_space::private_space,
374+
access::address_space::global_device_space,
375+
access::address_space::global_host_space>;
374376

375377
using nonlocal_address_space_list =
376378
address_space_list<access::address_space::global_space,
377379
access::address_space::private_space,
378-
access::address_space::constant_space>;
380+
access::address_space::constant_space,
381+
access::address_space::global_device_space,
382+
access::address_space::global_host_space>;
379383
} // namespace gvl
380384
} // namespace detail
381385
} // namespace sycl

sycl/include/CL/sycl/multi_ptr.hpp

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -108,17 +108,18 @@ template <typename ElementType, access::address_space Space> class multi_ptr {
108108
return reinterpret_cast<ReturnPtr>(m_Pointer)[index];
109109
}
110110

111-
// Only if Space == global_space
111+
// Only if Space == global_space || global_device_space
112112
template <int dimensions, access::mode Mode,
113113
access::placeholder isPlaceholder,
114114
access::address_space _Space = Space,
115115
typename = typename std::enable_if<
116116
_Space == Space &&
117-
Space == access::address_space::global_space>::type>
117+
(Space == access::address_space::global_space ||
118+
Space == access::address_space::global_device_space)>::type>
118119
multi_ptr(accessor<ElementType, dimensions, Mode,
119120
access::target::global_buffer, isPlaceholder>
120121
Accessor) {
121-
m_Pointer = (pointer_t)(Accessor.get_pointer().m_Pointer);
122+
m_Pointer = (pointer_t)(Accessor.get_pointer().get());
122123
}
123124

124125
// Only if Space == local_space
@@ -152,14 +153,17 @@ template <typename ElementType, access::address_space Space> class multi_ptr {
152153
// 2. from multi_ptr<ElementType, Space> to multi_ptr<const ElementType,
153154
// Space>
154155

155-
// Only if Space == global_space and element type is const
156-
template <
157-
int dimensions, access::mode Mode, access::placeholder isPlaceholder,
158-
access::address_space _Space = Space, typename ET = ElementType,
159-
typename = typename std::enable_if<
160-
_Space == Space && Space == access::address_space::global_space &&
161-
std::is_const<ET>::value &&
162-
std::is_same<ET, ElementType>::value>::type>
156+
// Only if Space == global_space || global_device_space and element type is
157+
// const
158+
template <int dimensions, access::mode Mode,
159+
access::placeholder isPlaceholder,
160+
access::address_space _Space = Space, typename ET = ElementType,
161+
typename = typename std::enable_if<
162+
_Space == Space &&
163+
(Space == access::address_space::global_space ||
164+
Space == access::address_space::global_device_space) &&
165+
std::is_const<ET>::value &&
166+
std::is_same<ET, ElementType>::value>::type>
163167
multi_ptr(accessor<typename std::remove_const<ET>::type, dimensions, Mode,
164168
access::target::global_buffer, isPlaceholder>
165169
Accessor)
@@ -345,12 +349,13 @@ template <access::address_space Space> class multi_ptr<void, Space> {
345349
return *this;
346350
}
347351

348-
// Only if Space == global_space
352+
// Only if Space == global_space || global_device_space
349353
template <typename ElementType, int dimensions, access::mode Mode,
350354
access::address_space _Space = Space,
351355
typename = typename std::enable_if<
352356
_Space == Space &&
353-
Space == access::address_space::global_space>::type>
357+
(Space == access::address_space::global_space ||
358+
Space == access::address_space::global_device_space)>::type>
354359
multi_ptr(
355360
accessor<ElementType, dimensions, Mode, access::target::global_buffer,
356361
access::placeholder::false_t>
@@ -466,12 +471,13 @@ class multi_ptr<const void, Space> {
466471
return *this;
467472
}
468473

469-
// Only if Space == global_space
474+
// Only if Space == global_space || global_device_space
470475
template <typename ElementType, int dimensions, access::mode Mode,
471476
access::address_space _Space = Space,
472477
typename = typename std::enable_if<
473478
_Space == Space &&
474-
Space == access::address_space::global_space>::type>
479+
(Space == access::address_space::global_space ||
480+
Space == access::address_space::global_device_space)>::type>
475481
multi_ptr(
476482
accessor<ElementType, dimensions, Mode, access::target::global_buffer,
477483
access::placeholder::false_t>

sycl/include/CL/sycl/pointers.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ template <typename ElementType, access::address_space Space> class multi_ptr;
1919
template <typename ElementType>
2020
using global_ptr = multi_ptr<ElementType, access::address_space::global_space>;
2121

22+
template <typename ElementType>
23+
using device_ptr =
24+
multi_ptr<ElementType, access::address_space::global_device_space>;
25+
26+
template <typename ElementType>
27+
using host_ptr =
28+
multi_ptr<ElementType, access::address_space::global_host_space>;
29+
2230
template <typename ElementType>
2331
using local_ptr = multi_ptr<ElementType, access::address_space::local_space>;
2432

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: %clangxx -fsycl-device-only -Xclang -fsycl-is-device -emit-llvm %s -S -o %t.ll -I %sycl_include -Wno-sycl-strict -Xclang -verify-ignore-unexpected=note,warning
2+
// RUN: FileCheck %s --input-file %t.ll
3+
//
4+
// Check the address space of the pointer in multi_ptr class
5+
//
6+
// CHECK: %[[DEVPTR_T:.*]] = type { i8 addrspace(5)* }
7+
// CHECK: %[[HOSTPTR_T:.*]] = type { i8 addrspace(6)* }
8+
//
9+
// CHECK-LABEL: define {{.*}} spir_func i8 addrspace(4)* @{{.*}}multi_ptr{{.*}}
10+
// CHECK: %m_Pointer = getelementptr inbounds %[[DEVPTR_T]]
11+
// CHECK-NEXT: %[[DEVLOAD:[0-9]+]] = load i8 addrspace(5)*, i8 addrspace(5)* addrspace(4)* %m_Pointer
12+
// CHECK-NEXT: %[[DEVCAST:[0-9]+]] = addrspacecast i8 addrspace(5)* %[[DEVLOAD]] to i8 addrspace(4)*
13+
// ret i8 addrspace(4)* %[[DEVCAST]]
14+
//
15+
// CHECK-LABEL: define {{.*}} spir_func i8 addrspace(4)* @{{.*}}multi_ptr{{.*}}
16+
// CHECK: %m_Pointer = getelementptr inbounds %[[HOSTPTR_T]]
17+
// CHECK-NEXT: %[[HOSTLOAD:[0-9]+]] = load i8 addrspace(6)*, i8 addrspace(6)* addrspace(4)* %m_Pointer
18+
// CHECK-NEXT: %[[HOSTCAST:[0-9]+]] = addrspacecast i8 addrspace(6)* %[[HOSTLOAD]] to i8 addrspace(4)*
19+
// ret i8 addrspace(4)* %[[HOSTCAST]]
20+
21+
#include <CL/sycl.hpp>
22+
23+
using namespace cl::sycl;
24+
25+
int main() {
26+
cl::sycl::queue queue;
27+
{
28+
queue.submit([&](cl::sycl::handler &cgh) {
29+
cgh.single_task<class check_adress_space>([=]() {
30+
void *Ptr = nullptr;
31+
device_ptr<void> DevPtr(Ptr);
32+
host_ptr<void> HostPtr(Ptr);
33+
global_ptr<void> GlobPtr = global_ptr<void>(DevPtr);
34+
GlobPtr = global_ptr<void>(HostPtr);
35+
});
36+
});
37+
queue.wait();
38+
}
39+
40+
return 0;
41+
}

sycl/test/multi_ptr/multi_ptr.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ template <typename T> void testMultPtr() {
8282
auto local_ptr = make_ptr<T, access::address_space::local_space>(
8383
localAccessor.get_pointer());
8484

85+
// General conversions in multi_ptr class
8586
T *RawPtr = nullptr;
8687
global_ptr<T> ptr_4(RawPtr);
8788
ptr_4 = RawPtr;
@@ -92,6 +93,12 @@ template <typename T> void testMultPtr() {
9293

9394
ptr_6 = (void *)RawPtr;
9495

96+
// Explicit conversions for device_ptr/host_ptr to global_ptr
97+
device_ptr<void> ptr_7((void *)RawPtr);
98+
global_ptr<void> ptr_8 = global_ptr<void>(ptr_7);
99+
host_ptr<void> ptr_9((void *)RawPtr);
100+
global_ptr<void> ptr_10 = global_ptr<void>(ptr_9);
101+
95102
innerFunc<T>(wiID.get(0), ptr_1, ptr_2, local_ptr);
96103
});
97104
});
@@ -109,12 +116,14 @@ void testMultPtrArrowOperator() {
109116
point<T> data_1[1] = {1};
110117
point<T> data_2[1] = {2};
111118
point<T> data_3[1] = {3};
119+
point<T> data_4[1] = {4};
112120

113121
{
114122
range<1> numOfItems{1};
115123
buffer<point<T>, 1> bufferData_1(data_1, numOfItems);
116124
buffer<point<T>, 1> bufferData_2(data_2, numOfItems);
117125
buffer<point<T>, 1> bufferData_3(data_3, numOfItems);
126+
buffer<point<T>, 1> bufferData_4(data_4, numOfItems);
118127
queue myQueue;
119128
myQueue.submit([&](handler &cgh) {
120129
accessor<point<T>, 1, access::mode::read, access::target::global_buffer,
@@ -126,6 +135,9 @@ void testMultPtrArrowOperator() {
126135
accessor<point<T>, 1, access::mode::read_write, access::target::local,
127136
access::placeholder::false_t>
128137
accessorData_3(1, cgh);
138+
accessor<point<T>, 1, access::mode::read, access::target::global_buffer,
139+
access::placeholder::false_t>
140+
accessorData_4(bufferData_4, cgh);
129141

130142
cgh.single_task<class testMultPtrArrowOperatorKernel<T>>([=]() {
131143
auto ptr_1 = make_ptr<point<T>, access::address_space::global_space>(
@@ -134,17 +146,22 @@ void testMultPtrArrowOperator() {
134146
accessorData_2.get_pointer());
135147
auto ptr_3 = make_ptr<point<T>, access::address_space::local_space>(
136148
accessorData_3.get_pointer());
149+
auto ptr_4 = make_ptr<point<T>, access::address_space::global_device_space>(
150+
accessorData_4.get_pointer());
137151

138152
auto x1 = ptr_1->x;
139153
auto x2 = ptr_2->x;
140154
auto x3 = ptr_3->x;
155+
auto x4 = ptr_4->x;
141156

142157
static_assert(std::is_same<decltype(x1), T>::value,
143158
"Expected decltype(ptr_1->x) == T");
144159
static_assert(std::is_same<decltype(x2), T>::value,
145160
"Expected decltype(ptr_2->x) == T");
146161
static_assert(std::is_same<decltype(x3), T>::value,
147162
"Expected decltype(ptr_3->x) == T");
163+
static_assert(std::is_same<decltype(x4), T>::value,
164+
"Expected decltype(ptr_4->x) == T");
148165
});
149166
});
150167
}

0 commit comments

Comments
 (0)