Skip to content

Commit c514d25

Browse files
romanovvlad authored and bader committed
[SYCL][Scheduler] Refactor data transfer scheduler commands
This patch makes data transfer commands such as map and unmap more flexible. They can now be used not only when a host accessor needs to be initialized with a host pointer, but also to perform map/unmap in cases where we can or must avoid a copy that would require an additional memory allocation on the host. Signed-off-by: Vlad Romanov <[email protected]>
1 parent 3546a78 commit c514d25

File tree

3 files changed

+47
-52
lines changed

3 files changed

+47
-52
lines changed

sycl/include/CL/sycl/detail/scheduler/commands.hpp

Lines changed: 18 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -245,13 +245,12 @@ class AllocaSubBufCommand : public AllocaCommandBase {
245245

246246
class MapMemObject : public Command {
247247
public:
248-
MapMemObject(Requirement SrcReq, AllocaCommandBase *SrcAlloca,
249-
Requirement *DstAcc, QueueImplPtr Queue);
248+
MapMemObject(AllocaCommandBase *SrcAlloca, Requirement *Req, void **DstPtr,
249+
QueueImplPtr Queue);
250250

251-
Requirement MSrcReq;
252251
AllocaCommandBase *MSrcAlloca = nullptr;
253-
Requirement *MDstAcc = nullptr;
254-
Requirement MDstReq;
252+
void **MDstPtr = nullptr;
253+
Requirement MReq;
255254

256255
void printDot(std::ostream &Stream) const override;
257256

@@ -261,18 +260,17 @@ class MapMemObject : public Command {
261260

262261
class UnMapMemObject : public Command {
263262
public:
264-
UnMapMemObject(Requirement SrcReq, AllocaCommandBase *SrcAlloca,
265-
Requirement *DstAcc, QueueImplPtr Queue,
266-
bool UseExclusiveQueue = false);
263+
UnMapMemObject(AllocaCommandBase *DstAlloca, Requirement *Req, void **SrcPtr,
264+
QueueImplPtr Queue, bool UseExclusiveQueue = false);
267265

268266
void printDot(std::ostream &Stream) const override;
269267

270268
private:
271269
cl_int enqueueImp() override;
272270

273-
Requirement MSrcReq;
274-
AllocaCommandBase *MSrcAlloca = nullptr;
275-
Requirement *MDstAcc = nullptr;
271+
AllocaCommandBase *MDstAlloca = nullptr;
272+
Requirement MReq;
273+
void **MSrcPtr = nullptr;
276274
};
277275

278276
// The command enqueues memory copy between two instances of memory object.
@@ -304,14 +302,14 @@ class MemCpyCommand : public Command {
304302
class MemCpyCommandHost : public Command {
305303
public:
306304
MemCpyCommandHost(Requirement SrcReq, AllocaCommandBase *SrcAlloca,
307-
Requirement *DstAcc, QueueImplPtr SrcQueue,
305+
Requirement DstReq, void **DstPtr, QueueImplPtr SrcQueue,
308306
QueueImplPtr DstQueue);
309307

310308
QueueImplPtr MSrcQueue;
311309
Requirement MSrcReq;
312310
AllocaCommandBase *MSrcAlloca = nullptr;
313311
Requirement MDstReq;
314-
Requirement *MDstAcc = nullptr;
312+
void **MDstPtr = nullptr;
315313

316314
void printDot(std::ostream &Stream) const override;
317315

@@ -341,21 +339,20 @@ class ExecCGCommand : public Command {
341339

342340
class UpdateHostRequirementCommand : public Command {
343341
public:
344-
UpdateHostRequirementCommand(QueueImplPtr Queue, Requirement *Req,
345-
AllocaCommandBase *AllocaForReq)
342+
UpdateHostRequirementCommand(QueueImplPtr Queue, AllocaCommandBase *AllocaCmd,
343+
Requirement *Req, void **DstPtr)
346344
: Command(CommandType::UPDATE_REQUIREMENT, std::move(Queue)),
347-
MReqToUpdate(Req), MAllocaForReq(AllocaForReq),
348-
MStoredRequirement(*Req) {}
345+
MDstPtr(DstPtr), MAllocaCmd(AllocaCmd), MReq(*Req) {}
349346

350-
Requirement *getStoredRequirement() { return &MStoredRequirement; }
347+
Requirement *getStoredRequirement() { return &MReq; }
351348

352349
private:
353350
cl_int enqueueImp() override;
354351
void printDot(std::ostream &Stream) const override;
355352

356-
Requirement *MReqToUpdate = nullptr;
357-
AllocaCommandBase *MAllocaForReq = nullptr;
358-
Requirement MStoredRequirement;
353+
void **MDstPtr = nullptr;
354+
AllocaCommandBase *MAllocaCmd = nullptr;
355+
Requirement MReq;
359356
};
360357

361358
} // namespace detail

sycl/source/detail/scheduler/commands.cpp

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -274,24 +274,22 @@ void ReleaseCommand::printDot(std::ostream &Stream) const {
274274
}
275275
}
276276

277-
MapMemObject::MapMemObject(Requirement SrcReq, AllocaCommandBase *SrcAlloca,
278-
Requirement *DstAcc, QueueImplPtr Queue)
277+
MapMemObject::MapMemObject(AllocaCommandBase *SrcAlloca, Requirement *Req,
278+
void **DstPtr, QueueImplPtr Queue)
279279
: Command(CommandType::MAP_MEM_OBJ, std::move(Queue)),
280-
MSrcReq(std::move(SrcReq)), MSrcAlloca(SrcAlloca), MDstAcc(DstAcc),
281-
MDstReq(*DstAcc) {}
280+
MSrcAlloca(SrcAlloca), MDstPtr(DstPtr), MReq(*Req) {}
282281

283282
cl_int MapMemObject::enqueueImp() {
284283
std::vector<RT::PiEvent> RawEvents =
285284
Command::prepareEvents(detail::getSyclObjImpl(MQueue->get_context()));
286-
assert(MDstReq.MDims == 1);
285+
assert(MReq.MDims == 1);
287286

288287
RT::PiEvent &Event = MEvent->getHandleRef();
289288
void *MappedPtr = MemoryManager::map(
290289
MSrcAlloca->getSYCLMemObj(), MSrcAlloca->getMemAllocation(), MQueue,
291-
MDstReq.MAccessMode, MDstReq.MDims, MDstReq.MMemoryRange,
292-
MDstReq.MAccessRange, MDstReq.MOffset, MDstReq.MElemSize,
293-
std::move(RawEvents), Event);
294-
MDstAcc->MData = MappedPtr;
290+
MReq.MAccessMode, MReq.MDims, MReq.MMemoryRange, MReq.MAccessRange,
291+
MReq.MOffset, MReq.MElemSize, std::move(RawEvents), Event);
292+
*MDstPtr = MappedPtr;
295293
return CL_SUCCESS;
296294
}
297295

@@ -311,19 +309,19 @@ void MapMemObject::printDot(std::ostream &Stream) const {
311309
}
312310
}
313311

314-
UnMapMemObject::UnMapMemObject(Requirement SrcReq, AllocaCommandBase *SrcAlloca,
315-
Requirement *DstAcc, QueueImplPtr Queue,
312+
UnMapMemObject::UnMapMemObject(AllocaCommandBase *DstAlloca, Requirement *Req,
313+
void **SrcPtr, QueueImplPtr Queue,
316314
bool UseExclusiveQueue)
317315
: Command(CommandType::UNMAP_MEM_OBJ, std::move(Queue), UseExclusiveQueue),
318-
MSrcReq(std::move(SrcReq)), MSrcAlloca(SrcAlloca), MDstAcc(DstAcc) {}
316+
MDstAlloca(DstAlloca), MReq(*Req), MSrcPtr(SrcPtr) {}
319317

320318
cl_int UnMapMemObject::enqueueImp() {
321319
std::vector<RT::PiEvent> RawEvents =
322320
Command::prepareEvents(detail::getSyclObjImpl(MQueue->get_context()));
323321

324322
RT::PiEvent &Event = MEvent->getHandleRef();
325-
MemoryManager::unmap(MSrcAlloca->getSYCLMemObj(),
326-
MSrcAlloca->getMemAllocation(), MQueue, MDstAcc->MData,
323+
MemoryManager::unmap(MDstAlloca->getSYCLMemObj(),
324+
MDstAlloca->getMemAllocation(), MQueue, *MSrcPtr,
327325
std::move(RawEvents), MUseExclusiveQueue, Event);
328326
return CL_SUCCESS;
329327
}
@@ -427,9 +425,10 @@ cl_int UpdateHostRequirementCommand::enqueueImp() {
427425
RT::PiEvent &Event = MEvent->getHandleRef();
428426
Command::waitForEvents(MQueue, RawEvents, Event);
429427

430-
assert(MAllocaForReq && "Expected valid alloca command");
431-
assert(MReqToUpdate && "Expected valid requirement");
432-
MReqToUpdate->MData = MAllocaForReq->getMemAllocation();
428+
assert(MAllocaCmd && "Expected valid alloca command");
429+
assert(MAllocaCmd->getMemAllocation() && "Expected valid source pointer");
430+
assert(MDstPtr && "Expected valid target pointer");
431+
*MDstPtr = MAllocaCmd->getMemAllocation();
433432
return CL_SUCCESS;
434433
}
435434

@@ -438,12 +437,11 @@ void UpdateHostRequirementCommand::printDot(std::ostream &Stream) const {
438437

439438
Stream << "ID = " << this << "\n";
440439
Stream << "UPDATE REQ ON " << deviceToString(MQueue->get_device()) << "\\n";
441-
bool IsReqOnBuffer = MStoredRequirement.MSYCLMemObj->getType() ==
442-
SYCLMemObjI::MemObjType::BUFFER;
440+
bool IsReqOnBuffer =
441+
MReq.MSYCLMemObj->getType() == SYCLMemObjI::MemObjType::BUFFER;
443442
Stream << "TYPE: " << (IsReqOnBuffer ? "Buffer" : "Image") << "\\n";
444443
if (IsReqOnBuffer)
445-
Stream << "Is sub buffer: " << std::boolalpha
446-
<< MStoredRequirement.MIsSubBuffer << "\\n";
444+
Stream << "Is sub buffer: " << std::boolalpha << MReq.MIsSubBuffer << "\\n";
447445

448446
Stream << "\"];" << std::endl;
449447

@@ -458,11 +456,12 @@ void UpdateHostRequirementCommand::printDot(std::ostream &Stream) const {
458456

459457
MemCpyCommandHost::MemCpyCommandHost(Requirement SrcReq,
460458
AllocaCommandBase *SrcAlloca,
461-
Requirement *DstAcc, QueueImplPtr SrcQueue,
459+
Requirement DstReq, void **DstPtr,
460+
QueueImplPtr SrcQueue,
462461
QueueImplPtr DstQueue)
463462
: Command(CommandType::COPY_MEMORY, std::move(DstQueue)),
464463
MSrcQueue(SrcQueue), MSrcReq(std::move(SrcReq)), MSrcAlloca(SrcAlloca),
465-
MDstReq(*DstAcc), MDstAcc(DstAcc) {
464+
MDstReq(std::move(DstReq)), MDstPtr(DstPtr) {
466465
if (!MSrcQueue->is_host())
467466
MEvent->setContextImpl(detail::getSyclObjImpl(MSrcQueue->get_context()));
468467
}
@@ -485,7 +484,7 @@ cl_int MemCpyCommandHost::enqueueImp() {
485484
MemoryManager::copy(
486485
MSrcAlloca->getSYCLMemObj(), MSrcAlloca->getMemAllocation(), MSrcQueue,
487486
MSrcReq.MDims, MSrcReq.MMemoryRange, MSrcReq.MAccessRange,
488-
MSrcReq.MOffset, MSrcReq.MElemSize, MDstAcc->MData, MQueue, MDstReq.MDims,
487+
MSrcReq.MOffset, MSrcReq.MElemSize, *MDstPtr, MQueue, MDstReq.MDims,
489488
MDstReq.MMemoryRange, MDstReq.MAccessRange, MDstReq.MOffset,
490489
MDstReq.MElemSize, std::move(RawEvents), MUseExclusiveQueue, Event);
491490
return CL_SUCCESS;

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ UpdateHostRequirementCommand *Scheduler::GraphBuilder::insertUpdateHostReqCmd(
157157
findAllocaForReq(Record, Req, Queue->get_context_impl());
158158
assert(AllocaCmd && "There must be alloca for requirement!");
159159
UpdateHostRequirementCommand *UpdateCommand =
160-
new UpdateHostRequirementCommand(Queue, Req, AllocaCmd);
160+
new UpdateHostRequirementCommand(Queue, AllocaCmd, Req, &Req->MData);
161161
// Need copy of requirement because after host accessor destructor call
162162
// dependencies become invalid if requirement is stored by pointer.
163163
Requirement *StoredReq = UpdateCommand->getStoredRequirement();
@@ -243,7 +243,7 @@ Command *Scheduler::GraphBuilder::addCopyBack(Requirement *Req) {
243243
findAllocaForReq(Record, Req, Record->MCurContext);
244244

245245
std::unique_ptr<MemCpyCommandHost> MemCpyCmdUniquePtr(new MemCpyCommandHost(
246-
*SrcAllocaCmd->getAllocationReq(), SrcAllocaCmd, Req,
246+
*SrcAllocaCmd->getAllocationReq(), SrcAllocaCmd, *Req, &Req->MData,
247247
SrcAllocaCmd->getQueue(), std::move(HostQueue)));
248248

249249
if (!MemCpyCmdUniquePtr)
@@ -286,7 +286,6 @@ Command *Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
286286

287287
AllocaCommandBase *SrcAllocaCmd =
288288
getOrCreateAllocaForReq(Record, Req, SrcQueue);
289-
Requirement *SrcReq = SrcAllocaCmd->getAllocationReq();
290289
if (SrcQueue->is_host()) {
291290
UpdateHostRequirementCommand *UpdateCmd =
292291
insertUpdateHostReqCmd(Record, Req, SrcQueue);
@@ -313,7 +312,7 @@ Command *Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
313312
Req->MSYCLMemObj->getType() == detail::SYCLMemObjI::MemObjType::BUFFER) {
314313

315314
std::unique_ptr<MapMemObject> MapCmdUniquePtr(
316-
new MapMemObject(*SrcReq, SrcAllocaCmd, Req, SrcQueue));
315+
new MapMemObject(SrcAllocaCmd, Req, &Req->MData, SrcQueue));
317316

318317
/*
319318
[SYCL] Use exclusive queues for blocked commands.
@@ -442,19 +441,19 @@ Command *Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
442441
*/
443442

444443
std::unique_ptr<UnMapMemObject> UnMapCmdUniquePtr(new UnMapMemObject(
445-
*SrcReq, SrcAllocaCmd, Req, SrcQueue, /*UseExclusiveQueue*/ true));
444+
SrcAllocaCmd, Req, &Req->MData, SrcQueue, /*UseExclusiveQueue*/ true));
446445

447446
if (!MapCmdUniquePtr || !UnMapCmdUniquePtr)
448447
throw runtime_error("Out of host memory");
449448

450449
MapMemObject *MapCmd = MapCmdUniquePtr.release();
451450
for (Command *Dep : Deps) {
452-
MapCmd->addDep(DepDesc{Dep, &MapCmd->MDstReq, SrcAllocaCmd});
451+
MapCmd->addDep(DepDesc{Dep, &MapCmd->MReq, SrcAllocaCmd});
453452
Dep->addUser(MapCmd);
454453
}
455454

456455
Command *UnMapCmd = UnMapCmdUniquePtr.release();
457-
UnMapCmd->addDep(DepDesc{MapCmd, &MapCmd->MDstReq, SrcAllocaCmd});
456+
UnMapCmd->addDep(DepDesc{MapCmd, &MapCmd->MReq, SrcAllocaCmd});
458457
MapCmd->addUser(UnMapCmd);
459458

460459
UpdateLeafs(Deps, Record, Req->MAccessMode);

0 commit comments

Comments
 (0)