@@ -679,6 +679,16 @@ Scheduler::GraphBuilder::addEmptyCmd(Command *Cmd, const std::vector<T *> &Reqs,
679
679
return EmptyCmd;
680
680
}
681
681
682
+ static bool isInteropHostTask (const std::unique_ptr<ExecCGCommand> &Cmd) {
683
+ if (Cmd->getCG ().getType () != CG::CGTYPE::CODEPLAY_HOST_TASK)
684
+ return false ;
685
+
686
+ const detail::CGHostTask &HT =
687
+ static_cast <detail::CGHostTask &>(Cmd->getCG ());
688
+
689
+ return HT.MHostTask ->isInteropTask ();
690
+ }
691
+
682
692
Command *
683
693
Scheduler::GraphBuilder::addCG (std::unique_ptr<detail::CG> CommandGroup,
684
694
QueueImplPtr Queue) {
@@ -695,13 +705,27 @@ Scheduler::GraphBuilder::addCG(std::unique_ptr<detail::CG> CommandGroup,
695
705
printGraphAsDot (" before_addCG" );
696
706
697
707
for (Requirement *Req : Reqs) {
698
- MemObjRecord *Record = getOrInsertMemObjRecord (Queue, Req);
699
- markModifiedIfWrite (Record, Req);
708
+ MemObjRecord *Record = nullptr ;
709
+ AllocaCommandBase *AllocaCmd = nullptr ;
710
+
711
+ bool isSameCtx = false ;
712
+
713
+ {
714
+ const QueueImplPtr &QueueForAlloca = isInteropHostTask (NewCmd) ?
715
+ static_cast <detail::CGHostTask &>(NewCmd->getCG ()).MQueue : Queue;
716
+
717
+ Record = getOrInsertMemObjRecord (QueueForAlloca, Req);
718
+ markModifiedIfWrite (Record, Req);
719
+
720
+ AllocaCmd = getOrCreateAllocaForReq (Record, Req, QueueForAlloca);
721
+
722
+ isSameCtx =
723
+ sameCtx (QueueForAlloca->getContextImplPtr (), Record->MCurContext );
724
+ }
700
725
701
- AllocaCommandBase *AllocaCmd = getOrCreateAllocaForReq (Record, Req, Queue);
702
726
// If there is alloca command we need to check if the latest memory is in
703
727
// required context.
704
- if (sameCtx (Queue-> getContextImplPtr (), Record-> MCurContext ) ) {
728
+ if (isSameCtx ) {
705
729
// If the memory is already in the required host context, check if the
706
730
// required access mode is valid, remap if not.
707
731
if (Record->MCurContext ->is_host () &&
@@ -713,12 +737,11 @@ Scheduler::GraphBuilder::addCG(std::unique_ptr<detail::CG> CommandGroup,
713
737
bool NeedMemMoveToHost = false ;
714
738
auto MemMoveTargetQueue = Queue;
715
739
716
- if (CGType == CG::CGTYPE::CODEPLAY_HOST_TASK ) {
740
+ if (isInteropHostTask (NewCmd) ) {
717
741
const detail::CGHostTask &HT =
718
742
static_cast <detail::CGHostTask &>(NewCmd->getCG ());
719
743
720
- if (HT.MHostTask ->isInteropTask () && !HT.MQueue ->is_host () &&
721
- !Record->MCurContext ->is_host ()) {
744
+ if (HT.MQueue ->getContextImplPtr () != Record->MCurContext ) {
722
745
NeedMemMoveToHost = true ;
723
746
MemMoveTargetQueue = HT.MQueue ;
724
747
}
0 commit comments