Skip to content

Commit 56b6fa7

Browse files
author
Sergey Kanaev
committed
[SYCL] Create an allocation on device for host-interop-task
Signed-off-by: Sergey Kanaev <[email protected]>
1 parent 285cb5a commit 56b6fa7

File tree

1 file changed

+30
-7
lines changed

1 file changed

+30
-7
lines changed

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,16 @@ Scheduler::GraphBuilder::addEmptyCmd(Command *Cmd, const std::vector<T *> &Reqs,
679679
return EmptyCmd;
680680
}
681681

682+
static bool isInteropHostTask(const std::unique_ptr<ExecCGCommand> &Cmd) {
683+
if (Cmd->getCG().getType() != CG::CGTYPE::CODEPLAY_HOST_TASK)
684+
return false;
685+
686+
const detail::CGHostTask &HT =
687+
static_cast<detail::CGHostTask &>(Cmd->getCG());
688+
689+
return HT.MHostTask->isInteropTask();
690+
}
691+
682692
Command *
683693
Scheduler::GraphBuilder::addCG(std::unique_ptr<detail::CG> CommandGroup,
684694
QueueImplPtr Queue) {
@@ -695,13 +705,27 @@ Scheduler::GraphBuilder::addCG(std::unique_ptr<detail::CG> CommandGroup,
695705
printGraphAsDot("before_addCG");
696706

697707
for (Requirement *Req : Reqs) {
698-
MemObjRecord *Record = getOrInsertMemObjRecord(Queue, Req);
699-
markModifiedIfWrite(Record, Req);
708+
MemObjRecord *Record = nullptr;
709+
AllocaCommandBase *AllocaCmd = nullptr;
710+
711+
bool isSameCtx = false;
712+
713+
{
714+
const QueueImplPtr &QueueForAlloca = isInteropHostTask(NewCmd) ?
715+
static_cast<detail::CGHostTask &>(NewCmd->getCG()).MQueue : Queue;
716+
717+
Record = getOrInsertMemObjRecord(QueueForAlloca, Req);
718+
markModifiedIfWrite(Record, Req);
719+
720+
AllocaCmd = getOrCreateAllocaForReq(Record, Req, QueueForAlloca);
721+
722+
isSameCtx =
723+
sameCtx(QueueForAlloca->getContextImplPtr(), Record->MCurContext);
724+
}
700725

701-
AllocaCommandBase *AllocaCmd = getOrCreateAllocaForReq(Record, Req, Queue);
702726
// If there is alloca command we need to check if the latest memory is in
703727
// required context.
704-
if (sameCtx(Queue->getContextImplPtr(), Record->MCurContext)) {
728+
if (isSameCtx) {
705729
// If the memory is already in the required host context, check if the
706730
// required access mode is valid, remap if not.
707731
if (Record->MCurContext->is_host() &&
@@ -713,12 +737,11 @@ Scheduler::GraphBuilder::addCG(std::unique_ptr<detail::CG> CommandGroup,
713737
bool NeedMemMoveToHost = false;
714738
auto MemMoveTargetQueue = Queue;
715739

716-
if (CGType == CG::CGTYPE::CODEPLAY_HOST_TASK) {
740+
if (isInteropHostTask(NewCmd)) {
717741
const detail::CGHostTask &HT =
718742
static_cast<detail::CGHostTask &>(NewCmd->getCG());
719743

720-
if (HT.MHostTask->isInteropTask() && !HT.MQueue->is_host() &&
721-
!Record->MCurContext->is_host()) {
744+
if (HT.MQueue->getContextImplPtr() != Record->MCurContext) {
722745
NeedMemMoveToHost = true;
723746
MemMoveTargetQueue = HT.MQueue;
724747
}

0 commit comments

Comments
 (0)