Skip to content

Commit 016c9f5

Browse files
committed
[SYCL] fixed USM malloc_shared and free to handle zero byte
Signed-off-by: Byoungro So <[email protected]>
1 parent e6ce614 commit 016c9f5

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

sycl/source/detail/usm/usm_impl.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ namespace usm {
2727
void *alignedAllocHost(size_t Alignment, size_t Size, const context &Ctxt,
2828
alloc Kind) {
2929
void *RetVal = nullptr;
30+
if (Size == 0)
31+
return nullptr;
3032
if (Ctxt.is_host()) {
3133
if (!Alignment) {
3234
// worst case default
@@ -72,6 +74,8 @@ void *alignedAllocHost(size_t Alignment, size_t Size, const context &Ctxt,
7274
void *alignedAlloc(size_t Alignment, size_t Size, const context &Ctxt,
7375
const device &Dev, alloc Kind) {
7476
void *RetVal = nullptr;
77+
if (Size == 0)
78+
return nullptr;
7579
if (Ctxt.is_host()) {
7680
if (Kind == alloc::unknown) {
7781
RetVal = nullptr;
@@ -126,6 +130,8 @@ void *alignedAlloc(size_t Alignment, size_t Size, const context &Ctxt,
126130
}
127131

128132
void free(void *Ptr, const context &Ctxt) {
133+
if (Ptr == nullptr)
134+
return;
129135
if (Ctxt.is_host()) {
130136
// need to use alignedFree here for Windows
131137
detail::OSUtil::alignedFree(Ptr);
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: env SYCL_DEVICE_TYPE=CPU %t.out
3+
// RUN: env SYCL_DEVICE_TYPE=GPU %t.out
4+
// RUN: env SYCL_DEVICE_TYPE=HOST %t.out
5+
6+
7+
//==-------------- usm_free.cpp - SYCL USM free malloc_shared and free test -------------==//
8+
//
9+
// This test checks if users will successfully allocate 160, 0, and -16 bytes of shared
10+
// memory, and also test user can call free() without worrying about nullptr or invalid
11+
// memory descriptor returned from malloc.
12+
//==-------------------------------------------------------------------------------------==//
13+
14+
#include <CL/sycl.hpp>
15+
#include <iostream>
16+
#include <stdlib.h>
17+
using namespace cl::sycl;
18+
19+
int main(int argc, char * argv[]) {
20+
auto exception_handler = [](cl::sycl::exception_list exceptions) {
21+
for (std::exception_ptr const &e : exceptions) {
22+
try {
23+
std::rethrow_exception(e);
24+
}
25+
catch (cl::sycl::exception const &e) {
26+
std::cout << "Caught asynchronous SYCL "
27+
"exception:\n"
28+
<< e.what() << std::endl;
29+
}
30+
}
31+
};
32+
33+
queue myQueue(default_selector{}, exception_handler);
34+
std::cout << "Device: " << myQueue.get_device().get_info<info::device::name>()
35+
<< std::endl;
36+
37+
double *ia = (double *)malloc_shared(160, myQueue);
38+
double *ja = (double *)malloc_shared(0, myQueue);
39+
double *result = (double *)malloc_shared(-16, myQueue);
40+
41+
std::cout << "ia : " << ia << " ja: " << ja << " result : " << result << std::endl;
42+
43+
// followings should not throws CL_INVALID_VALUE
44+
free(ia, myQueue);
45+
free(nullptr);
46+
free(ja, myQueue);
47+
free(result, myQueue);
48+
49+
return 0;
50+
}

0 commit comments

Comments
 (0)