16
16
#include " context.hpp"
17
17
#include " device.hpp"
18
18
#include " memory.hpp"
19
+ #include " ur2offload.hpp"
19
20
20
21
UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate (
21
22
ur_context_handle_t hContext, ur_mem_flags_t flags, size_t size,
@@ -33,14 +34,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
33
34
auto AllocMode = BufferMem::AllocMode::Default;
34
35
35
36
if (flags & UR_MEM_FLAG_ALLOC_HOST_POINTER) {
36
- olMemAlloc (OffloadDevice, OL_ALLOC_TYPE_HOST, size, &HostPtr);
37
+ auto Res = olMemAlloc (OffloadDevice, OL_ALLOC_TYPE_HOST, size, &HostPtr);
38
+ if (Res) {
39
+ return offloadResultToUR (Res);
40
+ }
37
41
// TODO: We (probably) need something like cuMemHostGetDevicePointer
38
42
// for this to work everywhere. For now assume the managed host pointer is
39
43
// device-accessible.
40
44
Ptr = HostPtr;
41
45
AllocMode = BufferMem::AllocMode::AllocHostPtr;
42
46
} else {
43
- olMemAlloc (OffloadDevice, OL_ALLOC_TYPE_DEVICE, size, &Ptr);
47
+ auto Res = olMemAlloc (OffloadDevice, OL_ALLOC_TYPE_DEVICE, size, &Ptr);
48
+ if (Res) {
49
+ return offloadResultToUR (Res);
50
+ }
44
51
if (flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER) {
45
52
AllocMode = BufferMem::AllocMode::CopyIn;
46
53
}
@@ -51,8 +58,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
51
58
hContext, ParentBuffer, flags, AllocMode, Ptr, HostPtr, size});
52
59
53
60
if (PerformInitialCopy) {
54
- olMemcpy (nullptr , Ptr, OffloadDevice, HostPtr, hContext->OffloadHost , size,
55
- nullptr );
61
+ auto Res = olMemcpy (nullptr , Ptr, OffloadDevice, HostPtr,
62
+ hContext->OffloadHost , size, nullptr );
63
+ if (Res) {
64
+ return offloadResultToUR (Res);
65
+ }
56
66
}
57
67
58
68
*phBuffer = URMemObj.release ();
@@ -74,7 +84,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
74
84
if (hMem->MemType == ur_mem_handle_t_::Type::Buffer) {
75
85
// TODO: Handle registered host memory
76
86
auto &BufferImpl = std::get<BufferMem>(MemObjPtr->Mem );
77
- olMemFree (BufferImpl.Ptr );
87
+ auto Res = olMemFree (BufferImpl.Ptr );
88
+ if (Res) {
89
+ return offloadResultToUR (Res);
90
+ }
78
91
}
79
92
80
93
return UR_RESULT_SUCCESS;
0 commit comments