Skip to content

Commit 44e49de

Browse files
[SYCL] Implement USM memcpy/memset on handlers.
Signed-off-by: Joshua Cranmer <[email protected]>
1 parent d0207ac commit 44e49de

File tree

8 files changed

+239
-3
lines changed

8 files changed

+239
-3
lines changed

sycl/doc/extensions/USM/USM.adoc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ class handler {
256256
...
257257
public:
258258
...
259-
event memcpy(void* dest, const void* src, size_t count);
259+
void memcpy(void* dest, const void* src, size_t count);
260260
};
261261
262262
class queue {
@@ -279,7 +279,7 @@ class handler {
279279
...
280280
public:
281281
...
282-
event memset(void* ptr, int value, size_t count);
282+
void memset(void* ptr, int value, size_t count);
283283
};
284284
285285
class queue {

sycl/include/CL/sycl/detail/cg.hpp

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,9 @@ class CG {
326326
COPY_ACC_TO_ACC,
327327
FILL,
328328
UPDATE_HOST,
329-
RUN_ON_HOST_INTEL
329+
RUN_ON_HOST_INTEL,
330+
COPY_USM,
331+
FILL_USM
330332
};
331333

332334
CG(CGTYPE Type, std::vector<std::vector<char>> ArgsStorage,
@@ -461,6 +463,51 @@ class CGUpdateHost : public CG {
461463
Requirement *getReqToUpdate() { return MPtr; }
462464
};
463465

466+
// The class which represents "copy" command group for USM pointers.
467+
class CGCopyUSM : public CG {
468+
void *MSrc;
469+
void *MDst;
470+
size_t MLength;
471+
472+
public:
473+
CGCopyUSM(void *Src, void *Dst, size_t Length,
474+
std::vector<std::vector<char>> ArgsStorage,
475+
std::vector<detail::AccessorImplPtr> AccStorage,
476+
std::vector<std::shared_ptr<const void>> SharedPtrStorage,
477+
std::vector<Requirement *> Requirements,
478+
std::vector<detail::EventImplPtr> Events)
479+
: CG(COPY_USM, std::move(ArgsStorage), std::move(AccStorage),
480+
std::move(SharedPtrStorage), std::move(Requirements),
481+
std::move(Events)),
482+
MSrc(Src), MDst(Dst), MLength(Length) {}
483+
484+
void *getSrc() { return MSrc; }
485+
void *getDst() { return MDst; }
486+
size_t getLength() { return MLength; }
487+
};
488+
489+
// The class which represents "fill" command group for USM pointers.
490+
class CGFillUSM : public CG {
491+
std::vector<char> MPattern;
492+
void *MDst;
493+
size_t MLength;
494+
495+
public:
496+
CGFillUSM(std::vector<char> Pattern, void *DstPtr, size_t Length,
497+
std::vector<std::vector<char>> ArgsStorage,
498+
std::vector<detail::AccessorImplPtr> AccStorage,
499+
std::vector<std::shared_ptr<const void>> SharedPtrStorage,
500+
std::vector<Requirement *> Requirements,
501+
std::vector<detail::EventImplPtr> Events)
502+
: CG(FILL_USM, std::move(ArgsStorage), std::move(AccStorage),
503+
std::move(SharedPtrStorage), std::move(Requirements),
504+
std::move(Events)),
505+
MPattern(std::move(Pattern)), MDst(DstPtr), MLength(Length) {}
506+
void *getDst() { return MDst; }
507+
size_t getLength() { return MLength; }
508+
int getFill() { return MPattern[0]; }
509+
};
510+
464511
} // namespace detail
465512
} // namespace sycl
466513
} // namespace cl

sycl/include/CL/sycl/detail/memory_manager.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,15 @@ class MemoryManager {
120120
static void unmap(SYCLMemObjI *SYCLMemObj, void *Mem, QueueImplPtr Queue,
121121
void *MappedPtr, std::vector<RT::PiEvent> DepEvents,
122122
bool UseExclusiveQueue, RT::PiEvent &OutEvent);
123+
124+
static void copy_usm(void *SrcMem, QueueImplPtr Queue, size_t Len,
125+
void *DstMem, std::vector<RT::PiEvent> DepEvents,
126+
bool UseExclusiveQueue, RT::PiEvent &OutEvent);
127+
128+
static void fill_usm(void *DstMem, QueueImplPtr Queue, size_t Len,
129+
int Pattern, std::vector<RT::PiEvent> DepEvents,
130+
RT::PiEvent &OutEvent);
131+
123132
};
124133
} // namespace detail
125134
} // namespace sycl

sycl/include/CL/sycl/handler.hpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ class handler {
164164
void *MSrcPtr = nullptr;
165165
// Pointer to the destination host memory or accessor (depends on command type).
166166
void *MDstPtr = nullptr;
167+
// Length to copy or fill (for USM operations).
168+
size_t MLength = 0;
167169
// Pattern that is used to fill memory object in case command type is fill.
168170
std::vector<char> MPattern;
169171
// Storage for a lambda or function object.
@@ -383,6 +385,18 @@ class handler {
383385
std::move(MSharedPtrStorage), std::move(MRequirements),
384386
std::move(MEvents)));
385387
break;
388+
case detail::CG::COPY_USM:
389+
CommandGroup.reset(new detail::CGCopyUSM(
390+
MSrcPtr, MDstPtr, MLength, std::move(MArgsStorage),
391+
std::move(MAccStorage), std::move(MSharedPtrStorage),
392+
std::move(MRequirements), std::move(MEvents)));
393+
break;
394+
case detail::CG::FILL_USM:
395+
CommandGroup.reset(new detail::CGFillUSM(
396+
std::move(MPattern), MDstPtr, MLength, std::move(MArgsStorage),
397+
std::move(MAccStorage), std::move(MSharedPtrStorage),
398+
std::move(MRequirements), std::move(MEvents)));
399+
break;
386400
case detail::CG::NONE:
387401
throw runtime_error("Command group submitted without a kernel or a "
388402
"explicit memory operation.");
@@ -1133,6 +1147,22 @@ class handler {
11331147
});
11341148
}
11351149
}
1150+
1151+
// Copy memory from the source to the destination.
1152+
void memcpy(void* Dest, const void* Src, size_t Count) {
1153+
MSrcPtr = const_cast<void *>(Src);
1154+
MDstPtr = Dest;
1155+
MLength = Count;
1156+
MCGType = detail::CG::COPY_USM;
1157+
}
1158+
1159+
// Fill the memory pointed to by the destination with the given bytes.
1160+
void memset(void *Dest, int Value, size_t Count) {
1161+
MDstPtr = Dest;
1162+
MPattern.push_back((char)Value);
1163+
MLength = Count;
1164+
MCGType = detail::CG::FILL_USM;
1165+
}
11361166
};
11371167
} // namespace sycl
11381168
} // namespace cl

sycl/source/detail/memory_manager.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <CL/sycl/detail/event_impl.hpp>
1111
#include <CL/sycl/detail/memory_manager.hpp>
1212
#include <CL/sycl/detail/queue_impl.hpp>
13+
#include <CL/sycl/detail/usm_dispatch.hpp>
1314

1415
#include <algorithm>
1516
#include <cassert>
@@ -462,6 +463,31 @@ void MemoryManager::unmap(SYCLMemObjI *SYCLMemObj, void *Mem,
462463
DepEvents.empty() ? nullptr : &DepEvents[0], &OutEvent));
463464
}
464465

466+
void MemoryManager::copy_usm(void *SrcMem, QueueImplPtr SrcQueue, size_t Len,
467+
void *DstMem, std::vector<RT::PiEvent> DepEvents,
468+
bool UseExclusiveQueue, RT::PiEvent &OutEvent) {
469+
RT::PiQueue Queue = UseExclusiveQueue
470+
? SrcQueue->getExclusiveQueueHandleRef()
471+
: SrcQueue->getHandleRef();
472+
473+
sycl::context Context = SrcQueue->get_context();
474+
std::shared_ptr<usm::USMDispatcher> USMDispatch =
475+
getSyclObjImpl(Context)->getUSMDispatch();
476+
PI_CHECK(USMDispatch->enqueueMemcpy(Queue,
477+
/* blocking */ false, DstMem, SrcMem, Len, DepEvents.size(),
478+
&DepEvents[0], &OutEvent));
479+
}
480+
481+
void MemoryManager::fill_usm(void *Mem, QueueImplPtr Queue, size_t Length,
482+
int Pattern, std::vector<RT::PiEvent> DepEvents,
483+
RT::PiEvent &OutEvent) {
484+
sycl::context Context = Queue->get_context();
485+
std::shared_ptr<usm::USMDispatcher> USMDispatch =
486+
getSyclObjImpl(Context)->getUSMDispatch();
487+
PI_CHECK(USMDispatch->enqueueMemset(Queue->getHandleRef(),
488+
Mem, Pattern, Length, DepEvents.size(), &DepEvents[0], &OutEvent));
489+
}
490+
465491
} // namespace detail
466492
} // namespace sycl
467493
} // namespace cl

sycl/source/detail/scheduler/commands.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,12 @@ void ExecCGCommand::printDot(std::ostream &Stream) const {
480480
case detail::CG::COPY_PTR_TO_ACC:
481481
Stream << "CG type: copy ptr to acc\\n";
482482
break;
483+
case detail::CG::COPY_USM:
484+
Stream << "CG type: copy usm\\n";
485+
break;
486+
case detail::CG::FILL_USM:
487+
Stream << "CG type: fill usm\\n";
488+
break;
483489
default:
484490
Stream << "CG type: unknown\\n";
485491
break;
@@ -766,6 +772,18 @@ cl_int ExecCGCommand::enqueueImp() {
766772

767773
return PI_SUCCESS;
768774
}
775+
case CG::CGTYPE::COPY_USM: {
776+
CGCopyUSM *Copy = (CGCopyUSM *)MCommandGroup.get();
777+
MemoryManager::copy_usm(Copy->getSrc(), MQueue, Copy->getLength(),
778+
Copy->getDst(), std::move(RawEvents), MUseExclusiveQueue, Event);
779+
return CL_SUCCESS;
780+
}
781+
case CG::CGTYPE::FILL_USM: {
782+
CGFillUSM *Fill = (CGFillUSM *)MCommandGroup.get();
783+
MemoryManager::fill_usm(Fill->getDst(), MQueue, Fill->getLength(),
784+
Fill->getFill(), std::move(RawEvents), Event);
785+
return CL_SUCCESS;
786+
}
769787
case CG::CGTYPE::NONE:
770788
default:
771789
throw runtime_error("CG type not implemented.");

sycl/test/usm/memcpy.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//==---- memcpy.cpp - USM memcpy test --------------------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// RUN: %clangxx -fsycl %s -o %t1.out -lOpenCL
9+
// RUN: %CPU_RUN_PLACEHOLDER %t1.out
10+
11+
#include <CL/sycl.hpp>
12+
13+
using namespace cl::sycl;
14+
15+
static constexpr int count = 100;
16+
17+
int main() {
18+
queue q([](exception_list el) {
19+
for (auto &e : el)
20+
std::rethrow_exception(e);
21+
});
22+
float *src = (float*)malloc_shared(sizeof(float) * count, q.get_device(),
23+
q.get_context());
24+
float *dest = (float*)malloc_shared(sizeof(float) * count, q.get_device(),
25+
q.get_context());
26+
for (int i = 0; i < count; i++)
27+
src[i] = i;
28+
29+
event init_copy = q.submit([&](handler &cgh) {
30+
cgh.memcpy(dest, src, sizeof(float) * count);
31+
});
32+
33+
q.submit([&](handler &cgh) {
34+
cgh.depends_on(init_copy);
35+
cgh.single_task<class double_dest>([=]() {
36+
for (int i = 0; i < count; i++)
37+
dest[i] *= 2;
38+
});
39+
});
40+
q.wait_and_throw();
41+
42+
for (int i = 0; i < count; i++) {
43+
assert(dest[i] == i * 2);
44+
}
45+
46+
// Copying to nullptr should throw.
47+
q.submit([&](handler &cgh) {
48+
cgh.memcpy(nullptr, src, sizeof(float) * count);
49+
});
50+
try {
51+
q.wait_and_throw();
52+
assert(false && "Expected error from copying to nullptr");
53+
} catch (runtime_error e) {
54+
}
55+
}

sycl/test/usm/memset.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//==---- memset.cpp - USM memset test --------------------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// RUN: %clangxx -fsycl %s -o %t1.out -lOpenCL
9+
// RUN: %CPU_RUN_PLACEHOLDER %t1.out
10+
11+
#include <CL/sycl.hpp>
12+
13+
using namespace cl::sycl;
14+
15+
static constexpr int count = 100;
16+
17+
int main() {
18+
queue q([](exception_list el) {
19+
for (auto &e : el)
20+
std::rethrow_exception(e);
21+
});
22+
uint32_t *src = (uint32_t*)malloc_shared(sizeof(uint32_t) * count, q.get_device(),
23+
q.get_context());
24+
25+
event init_copy = q.submit([&](handler &cgh) {
26+
cgh.memset(src, 0x15, sizeof(uint32_t) * count);
27+
});
28+
29+
q.submit([&](handler &cgh) {
30+
cgh.depends_on(init_copy);
31+
cgh.single_task<class double_dest>([=]() {
32+
for (int i = 0; i < count; i++)
33+
src[i] *= 2;
34+
});
35+
});
36+
q.wait_and_throw();
37+
38+
for (int i = 0; i < count; i++) {
39+
assert(src[i] == 0x2a2a2a2a);
40+
}
41+
42+
// Filling to nullptr should throw.
43+
q.submit([&](handler &cgh) {
44+
cgh.memset(nullptr, 0, sizeof(uint32_t) * count);
45+
});
46+
try {
47+
q.wait_and_throw();
48+
assert(false && "Expected error from writing to nullptr");
49+
} catch (runtime_error e) {
50+
}
51+
}

0 commit comments

Comments
 (0)