Skip to content

Commit 01869a0

Browse files
jbrodmanromanovvlad
authored andcommitted
[SYCL][USM] Fix bug with malloc(..., kind) impl and host allocations (#691)
Also rename host-only alignedAlloc to be less ambiguous Signed-off-by: James Brodman <[email protected]>
1 parent b207160 commit 01869a0

File tree

4 files changed

+98
-10
lines changed

4 files changed

+98
-10
lines changed

sycl/include/CL/sycl/detail/usm_impl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ namespace usm {
1919
void *alignedAlloc(size_t Alignment, size_t Bytes, const context &Ctxt,
2020
const device &Dev, cl::sycl::usm::alloc Kind);
2121

22-
void *alignedAlloc(size_t Alignment, size_t Bytes, const context &Ctxt,
23-
cl::sycl::usm::alloc Kind);
22+
void *alignedAllocHost(size_t Alignment, size_t Bytes, const context &Ctxt,
23+
cl::sycl::usm::alloc Kind);
2424

2525
void free(void *Ptr, const context &Ctxt);
2626

sycl/include/CL/sycl/usm/usm_allocator.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class usm_allocator {
109109
usm::alloc AllocT = AllocKind,
110110
typename std::enable_if<AllocT == usm::alloc::host, int>::type = 0>
111111
pointer allocate(size_t Size) {
112-
auto Result = reinterpret_cast<pointer>(detail::usm::alignedAlloc(
112+
auto Result = reinterpret_cast<pointer>(detail::usm::alignedAllocHost(
113113
getAlignment(), Size * sizeof(value_type), mContext, AllocKind));
114114
if (!Result) {
115115
throw memory_allocation_error();

sycl/source/detail/usm/usm_impl.cpp

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ using alloc = cl::sycl::usm::alloc;
2323
namespace detail {
2424
namespace usm {
2525

26-
void *alignedAlloc(size_t Alignment, size_t Size, const context &Ctxt,
27-
alloc Kind) {
26+
void *alignedAllocHost(size_t Alignment, size_t Size, const context &Ctxt,
27+
alloc Kind) {
2828
void *RetVal = nullptr;
2929
if (Ctxt.is_host()) {
3030
if (!Alignment) {
@@ -118,7 +118,7 @@ void *alignedAlloc(size_t Alignment, size_t Size, const context &Ctxt,
118118
}
119119
return RetVal;
120120
}
121-
121+
122122
void free(void *Ptr, const context &Ctxt) {
123123
if (Ctxt.is_host()) {
124124
// need to use alignedFree here for Windows
@@ -153,15 +153,15 @@ void free(void *ptr, const context &Ctxt) {
153153
// Restricted USM
154154
///
155155
void *malloc_host(size_t Size, const context &Ctxt) {
156-
return detail::usm::alignedAlloc(0, Size, Ctxt, alloc::host);
156+
return detail::usm::alignedAllocHost(0, Size, Ctxt, alloc::host);
157157
}
158158

159159
void *malloc_shared(size_t Size, const device &Dev, const context &Ctxt) {
160160
return detail::usm::alignedAlloc(0, Size, Ctxt, Dev, alloc::shared);
161161
}
162162

163163
void *aligned_alloc_host(size_t Alignment, size_t Size, const context &Ctxt) {
164-
return detail::usm::alignedAlloc(Alignment, Size, Ctxt, alloc::host);
164+
return detail::usm::alignedAllocHost(Alignment, Size, Ctxt, alloc::host);
165165
}
166166

167167
void *aligned_alloc_shared(size_t Alignment, size_t Size, const device &Dev,
@@ -172,12 +172,28 @@ void *aligned_alloc_shared(size_t Alignment, size_t Size, const device &Dev,
172172
// single form
173173

174174
void *malloc(size_t Size, const device &Dev, const context &Ctxt, alloc Kind) {
175-
return detail::usm::alignedAlloc(0, Size, Ctxt, Dev, Kind);
175+
void *RetVal = nullptr;
176+
177+
if (Kind == alloc::host) {
178+
RetVal = detail::usm::alignedAllocHost(0, Size, Ctxt, Kind);
179+
} else {
180+
RetVal = detail::usm::alignedAlloc(0, Size, Ctxt, Dev, Kind);
181+
}
182+
183+
return RetVal;
176184
}
177185

178186
void *aligned_alloc(size_t Alignment, size_t Size, const device &Dev,
179187
const context &Ctxt, alloc Kind) {
180-
return detail::usm::alignedAlloc(Alignment, Size, Ctxt, Dev, Kind);
188+
void *RetVal = nullptr;
189+
190+
if (Kind == alloc::host) {
191+
RetVal = detail::usm::alignedAllocHost(Alignment, Size, Ctxt, Kind);
192+
} else {
193+
RetVal = detail::usm::alignedAlloc(Alignment, Size, Ctxt, Dev, Kind);
194+
}
195+
196+
return RetVal;
181197
}
182198

183199
} // namespace sycl

sycl/test/usm/mixed2.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// RUN: %clangxx -fsycl %s -o %t1.out
2+
// RUN: env SYCL_DEVICE_TYPE=HOST %t1.out
3+
// RUN: %CPU_RUN_PLACEHOLDER %t1.out
4+
// RUN: %GPU_RUN_PLACEHOLDER %t1.out
5+
6+
//==------------------- mixed2.cpp - Mixed Memory test ---------------------==//
7+
//
8+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9+
// See https://llvm.org/LICENSE.txt for license information.
10+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include <CL/sycl.hpp>
15+
16+
using namespace cl::sycl;
17+
18+
class foo;
19+
int main() {
20+
int *darray = nullptr;
21+
int *sarray = nullptr;
22+
int *harray = nullptr;
23+
const int N = 4;
24+
const int MAGIC_NUM = 42;
25+
26+
queue q;
27+
auto dev = q.get_device();
28+
auto ctxt = q.get_context();
29+
30+
darray = (int *)malloc(N * sizeof(int), dev, ctxt, usm::alloc::device);
31+
if (darray == nullptr) {
32+
return -1;
33+
}
34+
sarray = (int *)malloc(N * sizeof(int), dev, ctxt, usm::alloc::shared);
35+
36+
if (sarray == nullptr) {
37+
return -1;
38+
}
39+
40+
harray = (int *)malloc(N * sizeof(int), dev, ctxt, usm::alloc::host);
41+
if (harray == nullptr) {
42+
return -1;
43+
}
44+
for (int i = 0; i < N; i++) {
45+
sarray[i] = MAGIC_NUM - 1;
46+
harray[i] = 1;
47+
}
48+
49+
auto e0 = q.memset(darray, 0, N * sizeof(int));
50+
e0.wait();
51+
52+
auto e1 = q.submit([=](handler &cgh) {
53+
cgh.single_task<class foo>([=]() {
54+
for (int i = 0; i < N; i++) {
55+
sarray[i] += darray[i] + harray[i];
56+
}
57+
});
58+
});
59+
60+
e1.wait();
61+
62+
for (int i = 0; i < N; i++) {
63+
if (sarray[i] != MAGIC_NUM) {
64+
return -1;
65+
}
66+
}
67+
free(darray, ctxt);
68+
free(sarray, ctxt);
69+
free(harray, ctxt);
70+
71+
return 0;
72+
}

0 commit comments

Comments
 (0)