@@ -1443,13 +1443,36 @@ Scheduler::GraphBuilder::completeFusion(QueueImplPtr Queue,
1443
1443
return LastEvent;
1444
1444
}
1445
1445
1446
+ // Inherit all event dependencies from the input commands in the fusion list.
1447
+ std::vector<EventImplPtr> FusedEventDeps;
1448
+ for (auto *Cmd : CmdList) {
1449
+ FusedEventDeps.insert (FusedEventDeps.end (),
1450
+ Cmd->getPreparedDepsEvents ().begin (),
1451
+ Cmd->getPreparedDepsEvents ().end ());
1452
+ FusedEventDeps.insert (FusedEventDeps.end (),
1453
+ Cmd->getPreparedHostDepsEvents ().begin (),
1454
+ Cmd->getPreparedHostDepsEvents ().end ());
1455
+ }
1456
+
1446
1457
// Remove internal explicit dependencies, i.e., explicit dependencies from one
1447
1458
// kernel in the fusion list to another kernel also in the fusion list.
1448
- auto &FusedEventDeps = FusedCG->MEvents ;
1449
1459
FusedEventDeps.erase (
1450
- std::remove_if (
1451
- FusedEventDeps.begin (), FusedEventDeps.end (),
1452
- [&](EventImplPtr &E) { return E->getCommand () == PlaceholderCmd; }),
1460
+ std::remove_if (FusedEventDeps.begin (), FusedEventDeps.end (),
1461
+ [&](EventImplPtr &E) {
1462
+ if (E->getCommand () == PlaceholderCmd) {
1463
+ return true ;
1464
+ }
1465
+ if (E->getCommand () &&
1466
+ static_cast <Command *>(E->getCommand ())->getType () ==
1467
+ Command::RUN_CG) {
1468
+ auto *RunCGCmd =
1469
+ static_cast <ExecCGCommand *>(E->getCommand ());
1470
+ if (RunCGCmd->MFusionCmd == PlaceholderCmd) {
1471
+ return true ;
1472
+ }
1473
+ }
1474
+ return false ;
1475
+ }),
1453
1476
FusedEventDeps.end ());
1454
1477
1455
1478
auto FusedKernelCmd =
@@ -1466,8 +1489,8 @@ Scheduler::GraphBuilder::completeFusion(QueueImplPtr Queue,
1466
1489
}
1467
1490
1468
1491
createGraphForCommand (FusedKernelCmd.get (), FusedKernelCmd->getCG (), false ,
1469
- FusedKernelCmd->getCG ().MRequirements ,
1470
- FusedKernelCmd-> getCG (). MEvents , Queue, ToEnqueue);
1492
+ FusedKernelCmd->getCG ().MRequirements , FusedEventDeps,
1493
+ Queue, ToEnqueue);
1471
1494
1472
1495
ToEnqueue.push_back (FusedKernelCmd.get ());
1473
1496
0 commit comments