Skip to content

Commit 8ab6d7f

Browse files
author
Alexander Batashev
authored
[SYCL] Refactor USM allocator to improve ABI stability (#1064)
Signed-off-by: Alexander Batashev <[email protected]>
1 parent ba30c3b commit 8ab6d7f

File tree

1 file changed

+39
-32
lines changed

1 file changed

+39
-32
lines changed

sycl/include/CL/sycl/usm/usm_allocator.hpp

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#pragma once
99

1010
#include <CL/sycl/context.hpp>
11-
#include <CL/sycl/detail/usm_impl.hpp>
1211
#include <CL/sycl/device.hpp>
1312
#include <CL/sycl/exception.hpp>
1413
#include <CL/sycl/queue.hpp>
@@ -20,6 +19,11 @@
2019
__SYCL_INLINE namespace cl {
2120
namespace sycl {
2221

22+
// Forward declarations.
23+
void *aligned_alloc(size_t alignment, size_t size, const device &dev,
24+
const context &ctxt, usm::alloc kind);
25+
void free(void *ptr, const context &ctxt);
26+
2327
template <typename T, usm::alloc AllocKind, size_t Alignment = 0>
2428
class usm_allocator {
2529
public:
@@ -36,14 +40,19 @@ class usm_allocator {
3640

3741
usm_allocator() = delete;
3842
usm_allocator(const context &Ctxt, const device &Dev)
39-
: mContext(Ctxt), mDevice(Dev) {}
43+
: MContext(Ctxt), MDevice(Dev) {}
4044
usm_allocator(const queue &Q)
41-
: mContext(Q.get_context()), mDevice(Q.get_device()) {}
45+
: MContext(Q.get_context()), MDevice(Q.get_device()) {}
4246
usm_allocator(const usm_allocator &Other)
43-
: mContext(Other.mContext), mDevice(Other.mDevice) {}
44-
45-
// Construct an object
46-
// Note: AllocKind == alloc::device is not allowed
47+
: MContext(Other.MContext), MDevice(Other.MDevice) {}
48+
49+
/// Constructs an object on memory pointed by Ptr.
50+
///
51+
/// Note: AllocKind == alloc::device is not allowed.
52+
///
53+
/// @param Ptr is a pointer to memory that will be used to construct the
54+
/// object.
55+
/// @param Val is a value to initialize the newly constructed object.
4756
template <
4857
usm::alloc AllocT = AllocKind,
4958
typename std::enable_if<AllocT != usm::alloc::device, int>::type = 0>
@@ -59,8 +68,11 @@ class usm_allocator {
5968
"Device pointers do not support construct on host");
6069
}
6170

62-
// Destroy an object
63-
// Note:: AllocKind == alloc::device is not allowed
71+
/// Destroys an object.
72+
///
73+
/// Note:: AllocKind == alloc::device is not allowed
74+
///
75+
/// @param Ptr is a pointer to memory where the object resides.
6476
template <
6577
usm::alloc AllocT = AllocKind,
6678
typename std::enable_if<AllocT != usm::alloc::device, int>::type = 0>
@@ -76,7 +88,10 @@ class usm_allocator {
7688
"Device pointers do not support destroy on host");
7789
}
7890

79-
// Note:: AllocKind == alloc::device is not allowed
91+
/// Note:: AllocKind == alloc::device is not allowed.
92+
///
93+
/// @param Val is a reference to object.
94+
/// @return an address of the object referenced by Val.
8095
template <
8196
usm::alloc AllocT = AllocKind,
8297
typename std::enable_if<AllocT != usm::alloc::device, int>::type = 0>
@@ -107,35 +122,27 @@ class usm_allocator {
107122
"Device pointers do not support address on host");
108123
}
109124

110-
// Allocate memory
111-
template <
112-
usm::alloc AllocT = AllocKind,
113-
typename std::enable_if<AllocT == usm::alloc::host, int>::type = 0>
114-
pointer allocate(size_t Size) {
115-
auto Result = reinterpret_cast<pointer>(detail::usm::alignedAllocHost(
116-
getAlignment(), Size * sizeof(value_type), mContext, AllocKind));
117-
if (!Result) {
118-
throw memory_allocation_error();
119-
}
120-
return Result;
121-
}
125+
/// Allocates memory.
126+
///
127+
/// @param NumberOfElements is a count of elements to allocate memory for.
128+
pointer allocate(size_t NumberOfElements) {
122129

123-
template <usm::alloc AllocT = AllocKind,
124-
typename std::enable_if<AllocT != usm::alloc::host, int>::type = 0>
125-
pointer allocate(size_t Size) {
126130
auto Result = reinterpret_cast<pointer>(
127-
detail::usm::alignedAlloc(getAlignment(), Size * sizeof(value_type),
128-
mContext, mDevice, AllocKind));
131+
aligned_alloc(getAlignment(), NumberOfElements * sizeof(value_type),
132+
MDevice, MContext, AllocKind));
129133
if (!Result) {
130134
throw memory_allocation_error();
131135
}
132136
return Result;
133137
}
134138

135-
// Deallocate memory
136-
void deallocate(pointer Ptr, size_t size) {
139+
/// Deallocates memory.
140+
///
141+
/// @param Ptr is a pointer to memory being deallocated.
142+
/// @param Size is a number of elements previously passed to allocate.
143+
void deallocate(pointer Ptr, size_t Size) {
137144
if (Ptr) {
138-
detail::usm::free(Ptr, mContext);
145+
free(Ptr, MContext);
139146
}
140147
}
141148

@@ -151,8 +158,8 @@ class usm_allocator {
151158
return Alignment;
152159
}
153160

154-
const context mContext;
155-
const device mDevice;
161+
const context MContext;
162+
const device MDevice;
156163
};
157164

158165
} // namespace sycl

0 commit comments

Comments
 (0)