Skip to content

Commit 5bdbb12

Browse files
committed
[SYCL][Fusion] Inherit all dependencies from fusion input
Signed-off-by: Lukas Sommer <[email protected]>
1 parent a6ccf6a commit 5bdbb12

File tree

1 file changed

+29
-6
lines changed

1 file changed

+29
-6
lines changed

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,13 +1443,36 @@ Scheduler::GraphBuilder::completeFusion(QueueImplPtr Queue,
14431443
return LastEvent;
14441444
}
14451445

1446+
// Inherit all event dependencies from the input commands in the fusion list.
1447+
std::vector<EventImplPtr> FusedEventDeps;
1448+
for (auto *Cmd : CmdList) {
1449+
FusedEventDeps.insert(FusedEventDeps.end(),
1450+
Cmd->getPreparedDepsEvents().begin(),
1451+
Cmd->getPreparedDepsEvents().end());
1452+
FusedEventDeps.insert(FusedEventDeps.end(),
1453+
Cmd->getPreparedHostDepsEvents().begin(),
1454+
Cmd->getPreparedHostDepsEvents().end());
1455+
}
1456+
14461457
// Remove internal explicit dependencies, i.e., explicit dependencies from one
14471458
// kernel in the fusion list to another kernel also in the fusion list.
1448-
auto &FusedEventDeps = FusedCG->MEvents;
14491459
FusedEventDeps.erase(
1450-
std::remove_if(
1451-
FusedEventDeps.begin(), FusedEventDeps.end(),
1452-
[&](EventImplPtr &E) { return E->getCommand() == PlaceholderCmd; }),
1460+
std::remove_if(FusedEventDeps.begin(), FusedEventDeps.end(),
1461+
[&](EventImplPtr &E) {
1462+
if (E->getCommand() == PlaceholderCmd) {
1463+
return true;
1464+
}
1465+
if (E->getCommand() &&
1466+
static_cast<Command *>(E->getCommand())->getType() ==
1467+
Command::RUN_CG) {
1468+
auto *RunCGCmd =
1469+
static_cast<ExecCGCommand *>(E->getCommand());
1470+
if (RunCGCmd->MFusionCmd == PlaceholderCmd) {
1471+
return true;
1472+
}
1473+
}
1474+
return false;
1475+
}),
14531476
FusedEventDeps.end());
14541477

14551478
auto FusedKernelCmd =
@@ -1466,8 +1489,8 @@ Scheduler::GraphBuilder::completeFusion(QueueImplPtr Queue,
14661489
}
14671490

14681491
createGraphForCommand(FusedKernelCmd.get(), FusedKernelCmd->getCG(), false,
1469-
FusedKernelCmd->getCG().MRequirements,
1470-
FusedKernelCmd->getCG().MEvents, Queue, ToEnqueue);
1492+
FusedKernelCmd->getCG().MRequirements, FusedEventDeps,
1493+
Queue, ToEnqueue);
14711494

14721495
ToEnqueue.push_back(FusedKernelCmd.get());
14731496

0 commit comments

Comments
 (0)