Skip to content

Commit 55b8dce

Browse files
Fznamznonvladimirlaz
authored andcommitted
[SYCL] Add support for custom allocators in buffer
Added missed constructors and get_allocator method. Signed-off-by: Mariya Podchishchaeva <[email protected]>
1 parent 1b96030 commit 55b8dce

File tree

4 files changed

+59
-51
lines changed

4 files changed

+59
-51
lines changed

sycl/include/CL/sycl/accessor.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,9 @@ SYCL_ACCESSOR_IMPL(!isTargetHostAccess(accessTarget) &&
150150
// reinterpret casting while setting kernel arguments in order to get cl_mem
151151
// value from the buffer regardless of the accessor's dimensionality.
152152
#ifndef __SYCL_DEVICE_ONLY__
153-
detail::buffer_impl<buffer_allocator<char>> *m_Buf = nullptr;
153+
detail::buffer_impl<buffer_allocator> *m_Buf = nullptr;
154154
#else
155-
char padding[sizeof(detail::buffer_impl<buffer_allocator<char>> *)];
155+
char padding[sizeof(detail::buffer_impl<buffer_allocator> *)];
156156
#endif // __SYCL_DEVICE_ONLY__
157157

158158
dataT *Data;
@@ -185,9 +185,9 @@ SYCL_ACCESSOR_IMPL(!isTargetHostAccess(accessTarget) &&
185185
// reinterpret casting while setting kernel arguments in order to get cl_mem
186186
// value from the buffer regardless of the accessor's dimensionality.
187187
#ifndef __SYCL_DEVICE_ONLY__
188-
detail::buffer_impl<buffer_allocator<char>> *m_Buf = nullptr;
188+
detail::buffer_impl<buffer_allocator> *m_Buf = nullptr;
189189
#else
190-
char padding[sizeof(detail::buffer_impl<buffer_allocator<char>> *)];
190+
char padding[sizeof(detail::buffer_impl<buffer_allocator> *)];
191191
#endif // __SYCL_DEVICE_ONLY__
192192

193193
dataT *Data;

sycl/include/CL/sycl/buffer.hpp

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class queue;
2121
template <int dimentions> class range;
2222

2323
template <typename T, int dimensions = 1,
24-
typename AllocatorT = cl::sycl::buffer_allocator<char>>
24+
typename AllocatorT = cl::sycl::buffer_allocator>
2525
class buffer {
2626
public:
2727
using value_type = T;
@@ -36,11 +36,11 @@ class buffer {
3636
get_count() * sizeof(T), propList);
3737
}
3838

39-
// buffer(const range<dimensions> &bufferRange, AllocatorT allocator,
40-
// const property_list &propList = {}) {
41-
// impl = std::make_shared<detail::buffer_impl>(bufferRange, allocator,
42-
// propList);
43-
// }
39+
buffer(const range<dimensions> &bufferRange, AllocatorT allocator,
40+
const property_list &propList = {}) {
41+
impl = std::make_shared<detail::buffer_impl<AllocatorT>>(
42+
get_count() * sizeof(T), propList, allocator);
43+
}
4444

4545
buffer(T *hostData, const range<dimensions> &bufferRange,
4646
const property_list &propList = {})
@@ -49,11 +49,11 @@ class buffer {
4949
hostData, get_count() * sizeof(T), propList);
5050
}
5151

52-
// buffer(T *hostData, const range<dimensions> &bufferRange,
53-
// AllocatorT allocator, const property_list &propList = {}) {
54-
// impl = std::make_shared<detail::buffer_impl>(hostData, bufferRange,
55-
// allocator, propList);
56-
// }
52+
buffer(T *hostData, const range<dimensions> &bufferRange,
53+
AllocatorT allocator, const property_list &propList = {}) {
54+
impl = std::make_shared<detail::buffer_impl<AllocatorT>>(
55+
hostData, get_count() * sizeof(T), propList, allocator);
56+
}
5757

5858
buffer(const T *hostData, const range<dimensions> &bufferRange,
5959
const property_list &propList = {})
@@ -62,18 +62,18 @@ class buffer {
6262
hostData, get_count() * sizeof(T), propList);
6363
}
6464

65-
// buffer(const T *hostData, const range<dimensions> &bufferRange,
66-
// AllocatorT allocator, const property_list &propList = {}) {
67-
// impl = std::make_shared<detail::buffer_impl>(hostData, bufferRange,
68-
// allocator, propList);
69-
// }
65+
buffer(const T *hostData, const range<dimensions> &bufferRange,
66+
AllocatorT allocator, const property_list &propList = {}) {
67+
impl = std::make_shared<detail::buffer_impl<AllocatorT>>(
68+
hostData, get_count() * sizeof(T), propList, allocator);
69+
}
7070

71-
// buffer(const shared_ptr_class<T> &hostData,
72-
// const range<dimensions> &bufferRange, AllocatorT allocator,
73-
// const property_list &propList = {}) {
74-
// impl = std::make_shared<detail::buffer_impl>(hostData, bufferRange,
75-
// allocator, propList);
76-
// }
71+
buffer(const shared_ptr_class<T> &hostData,
72+
const range<dimensions> &bufferRange, AllocatorT allocator,
73+
const property_list &propList = {}) {
74+
impl = std::make_shared<detail::buffer_impl<AllocatorT>>(
75+
hostData, get_count() * sizeof(T), propList, allocator);
76+
}
7777

7878
buffer(const shared_ptr_class<T> &hostData,
7979
const range<dimensions> &bufferRange,
@@ -83,12 +83,13 @@ class buffer {
8383
hostData, get_count() * sizeof(T), propList);
8484
}
8585

86-
// template <class InputIterator>
87-
// buffer<T, 1>(InputIterator first, InputIterator last, AllocatorT allocator,
88-
// const property_list &propList = {}) {
89-
// impl = std::make_shared<detail::buffer_impl>(first, last, allocator,
90-
// propList);
91-
// }
86+
template <class InputIterator>
87+
buffer(InputIterator first, InputIterator last, AllocatorT allocator,
88+
const property_list &propList = {})
89+
: Range(range<1>(std::distance(first, last))) {
90+
impl = std::make_shared<detail::buffer_impl<AllocatorT>>(
91+
first, last, get_count() * sizeof(T), propList, allocator);
92+
}
9293

9394
template <class InputIterator, int N = dimensions,
9495
typename = std::enable_if<N == 1>>
@@ -135,7 +136,7 @@ class buffer {
135136

136137
size_t get_size() const { return impl->get_size(); }
137138

138-
// AllocatorT get_allocator() const { return impl->get_allocator(); }
139+
AllocatorT get_allocator() const { return impl->get_allocator(); }
139140

140141
template <access::mode mode,
141142
access::target target = access::target::global_buffer>

sycl/include/CL/sycl/detail/buffer_impl.hpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,18 @@ class handler;
3737
class queue;
3838
template <int dimentions> class id;
3939
template <int dimentions> class range;
40-
template <class T> using buffer_allocator = std::allocator<T>;
40+
using buffer_allocator = std::allocator<char>;
4141
namespace detail {
4242
template <typename AllocatorT> class buffer_impl {
4343
public:
44-
buffer_impl(const size_t sizeInBytes, const property_list &propList)
45-
: buffer_impl((void *)nullptr, sizeInBytes, propList) {}
44+
buffer_impl(const size_t sizeInBytes, const property_list &propList,
45+
AllocatorT allocator = AllocatorT())
46+
: buffer_impl((void *)nullptr, sizeInBytes, propList, allocator) {}
4647

4748
buffer_impl(void *hostData, const size_t sizeInBytes,
48-
const property_list &propList)
49-
: SizeInBytes(sizeInBytes), Props(propList) {
49+
const property_list &propList,
50+
AllocatorT allocator = AllocatorT())
51+
: SizeInBytes(sizeInBytes), Props(propList), MAllocator(allocator) {
5052
if (Props.has_property<property::buffer::use_host_ptr>()) {
5153
BufPtr = hostData;
5254
} else {
@@ -62,8 +64,9 @@ template <typename AllocatorT> class buffer_impl {
6264

6365
// TODO temporary solution for allowing initialisation with const data
6466
buffer_impl(const void *hostData, const size_t sizeInBytes,
65-
const property_list &propList)
66-
: SizeInBytes(sizeInBytes), Props(propList) {
67+
const property_list &propList,
68+
AllocatorT allocator = AllocatorT())
69+
: SizeInBytes(sizeInBytes), Props(propList), MAllocator(allocator) {
6770
if (Props.has_property<property::buffer::use_host_ptr>()) {
6871
// TODO make this buffer read only
6972
BufPtr = const_cast<void *>(hostData);
@@ -79,8 +82,9 @@ template <typename AllocatorT> class buffer_impl {
7982

8083
template <typename T>
8184
buffer_impl(const shared_ptr_class<T> &hostData, const size_t sizeInBytes,
82-
const property_list &propList)
83-
: SizeInBytes(sizeInBytes), Props(propList) {
85+
const property_list &propList,
86+
AllocatorT allocator = AllocatorT())
87+
: SizeInBytes(sizeInBytes), Props(propList), MAllocator(allocator) {
8488
if (Props.has_property<property::buffer::use_host_ptr>()) {
8589
BufPtr = hostData.get();
8690
} else {
@@ -97,8 +101,9 @@ template <typename AllocatorT> class buffer_impl {
97101

98102
template <class InputIterator>
99103
buffer_impl(InputIterator first, InputIterator last, const size_t sizeInBytes,
100-
const property_list &propList)
101-
: SizeInBytes(sizeInBytes), Props(propList) {
104+
const property_list &propList,
105+
AllocatorT allocator = AllocatorT())
106+
: SizeInBytes(sizeInBytes), Props(propList), MAllocator(allocator) {
102107
if (Props.has_property<property::buffer::use_host_ptr>()) {
103108
// TODO next line looks unsafe
104109
BufPtr = &*first;
@@ -170,7 +175,7 @@ template <typename AllocatorT> class buffer_impl {
170175
throw cl::sycl::runtime_error(
171176
"set_final_data could not be used with interoperability buffer");
172177
static_assert(!std::is_const<Destination>::value,
173-
"Сan not write in a constant Destination. Destination should "
178+
"Can not write in a constant Destination. Destination should "
174179
"not be const.");
175180
uploadData = [this, final_data]() mutable {
176181
auto *Ptr =
@@ -182,6 +187,8 @@ template <typename AllocatorT> class buffer_impl {
182187
};
183188
}
184189

190+
AllocatorT get_allocator() const { return MAllocator; }
191+
185192
template <typename T, int dimensions, access::mode mode,
186193
access::target target = access::target::global_buffer>
187194
accessor<T, dimensions, mode, target, access::placeholder::false_t>
@@ -243,6 +250,7 @@ template <typename AllocatorT> class buffer_impl {
243250
// This field must be the first to guarantee that it's safe to use
244251
// reinterpret casting while setting kernel arguments in order to get cl_mem
245252
// value from the buffer regardless of its dimensionality.
253+
AllocatorT MAllocator;
246254
OpenCLMemState OCLState;
247255
bool OpenCLInterop = false;
248256
event AvailableEvent;

sycl/include/CL/sycl/detail/scheduler/scheduler.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ template <typename dataT, int dimensions, access::mode accessMode,
6161
void Node::addAccRequirement(
6262
accessor<dataT, dimensions, accessMode, accessTarget, isPlaceholder> &&Acc,
6363
int argIndex) {
64-
detail::buffer_impl<buffer_allocator<char>> *buf =
65-
Acc.__get_impl()->m_Buf;
64+
detail::buffer_impl<buffer_allocator> *buf = Acc.__get_impl()->m_Buf;
6665
addBufRequirement<accessMode, accessTarget>(*buf);
6766
addInteropArg(nullptr, buf->get_size(), argIndex,
6867
getReqForBuffer(m_Bufs, *buf));
@@ -134,7 +133,7 @@ void Node::addExplicitMemOp(
134133
auto *DestBase = Dest.__get_impl();
135134
assert(DestBase != nullptr &&
136135
"Accessor should have an initialized accessor_base");
137-
detail::buffer_impl<buffer_allocator<char>> *Buf = DestBase->m_Buf;
136+
detail::buffer_impl<buffer_allocator> *Buf = DestBase->m_Buf;
138137

139138
range<Dimensions> Range = DestBase->AccessRange;
140139
id<Dimensions> Offset = DestBase->Offset;
@@ -162,10 +161,10 @@ void Node::addExplicitMemOp(
162161
assert(DestBase != nullptr &&
163162
"Accessor should have an initialized accessor_base");
164163

165-
detail::buffer_impl<buffer_allocator<char>> *SrcBuf = SrcBase->m_Buf;
164+
detail::buffer_impl<buffer_allocator> *SrcBuf = SrcBase->m_Buf;
166165
assert(SrcBuf != nullptr &&
167166
"Accessor should have an initialized buffer_impl");
168-
detail::buffer_impl<buffer_allocator<char>> *DestBuf = DestBase->m_Buf;
167+
detail::buffer_impl<buffer_allocator> *DestBuf = DestBase->m_Buf;
169168
assert(DestBuf != nullptr &&
170169
"Accessor should have an initialized buffer_impl");
171170

@@ -195,7 +194,7 @@ void Scheduler::updateHost(
195194
auto *AccBase = Acc.__get_impl();
196195
assert(AccBase != nullptr &&
197196
"Accessor should have an initialized accessor_base");
198-
detail::buffer_impl<buffer_allocator<char>> *Buf = AccBase->m_Buf;
197+
detail::buffer_impl<buffer_allocator> *Buf = AccBase->m_Buf;
199198

200199
updateHost<mode, tgt>(*Buf, Event);
201200
}

0 commit comments

Comments
 (0)