@@ -985,22 +985,6 @@ Scheduler::GraphBuilder::addCG(std::unique_ptr<detail::CG> CommandGroup,
985
985
createGraphForCommand (NewCmd.get (), NewCmd->getCG (),
986
986
isInteropHostTask (NewCmd.get ()), Reqs, Events, Queue,
987
987
FusionCmd->auxiliaryCommands ());
988
- // We need to check the commands that this kernel depends on for any other
989
- // commands that have been submitted to another queue which is also in
990
- // fusion mode. If we detect such another command, we cancel fusion for that
991
- // other queue to avoid circular dependencies.
992
- // Handle requirements on any commands part of another active fusion.
993
- for (auto &Dep : NewCmd->MDeps ) {
994
- auto *DepCmd = Dep.MDepCommand ;
995
- if (!DepCmd) {
996
- continue ;
997
- }
998
- if (DepCmd->getQueue () != Queue && isPartOfActiveFusion (DepCmd)) {
999
- printFusionWarning (" Aborting fusion because of requirement from a "
1000
- " different fusion process" );
1001
- cancelFusion (DepCmd->getQueue (), ToEnqueue);
1002
- }
1003
- }
1004
988
1005
989
// Set the fusion command, so we recognize when another command depends on a
1006
990
// kernel in the fusion list.
@@ -1431,6 +1415,67 @@ void Scheduler::GraphBuilder::cancelFusion(QueueImplPtr Queue,
1431
1415
PlaceholderCmd->setFusionStatus (KernelFusionCommand::FusionStatus::CANCELLED);
1432
1416
}
1433
1417
1418
+ static bool isPartOfFusion (Command *Cmd, KernelFusionCommand *Fusion) {
1419
+ if (Cmd->getType () == Command::RUN_CG) {
1420
+ return static_cast <ExecCGCommand *>(Cmd)->MFusionCmd == Fusion;
1421
+ }
1422
+ return false ;
1423
+ }
1424
+
1425
+ static bool checkForCircularDependency (Command *, bool , KernelFusionCommand *);
1426
+
1427
+ static bool createsCircularDependency (Command *Cmd, bool PredPartOfFusion,
1428
+ KernelFusionCommand *Fusion) {
1429
+ if (isPartOfFusion (Cmd, Fusion)) {
1430
+ if (PredPartOfFusion) {
1431
+ // If this is part of the fusion and the predecessor also was, we can stop
1432
+ // the traversal here. A direct dependency between two kernels in the same
1433
+ // fusion will never form a cyclic dependency and by iterating over all
1434
+ // commands in a fusion, we will detect any cycles originating from the
1435
+ // current command.
1436
+ return false ;
1437
+ } else {
1438
+ // If the predecessor was not part of the fusion, but the current command
1439
+ // is, we have found a potential cycle in the dependency graph.
1440
+ return true ;
1441
+ }
1442
+ }
1443
+ return checkForCircularDependency (Cmd, false , Fusion);
1444
+ }
1445
+
1446
+ static bool checkForCircularDependency (Command *Cmd, bool IsPartOfFusion,
1447
+ KernelFusionCommand *Fusion) {
1448
+ // Check the requirement dependencies.
1449
+ for (auto &Dep : Cmd->MDeps ) {
1450
+ auto *DepCmd = Dep.MDepCommand ;
1451
+ if (!DepCmd) {
1452
+ continue ;
1453
+ }
1454
+ if (createsCircularDependency (DepCmd, IsPartOfFusion, Fusion)) {
1455
+ return true ;
1456
+ }
1457
+ }
1458
+ for (auto &Ev : Cmd->getPreparedDepsEvents ()) {
1459
+ auto *EvDepCmd = static_cast <Command *>(Ev->getCommand ());
1460
+ if (!EvDepCmd) {
1461
+ continue ;
1462
+ }
1463
+ if (createsCircularDependency (EvDepCmd, IsPartOfFusion, Fusion)) {
1464
+ return true ;
1465
+ }
1466
+ }
1467
+ for (auto &Ev : Cmd->getPreparedHostDepsEvents ()) {
1468
+ auto *EvDepCmd = static_cast <Command *>(Ev->getCommand ());
1469
+ if (!EvDepCmd) {
1470
+ continue ;
1471
+ }
1472
+ if (createsCircularDependency (EvDepCmd, IsPartOfFusion, Fusion)) {
1473
+ return true ;
1474
+ }
1475
+ }
1476
+ return false ;
1477
+ }
1478
+
1434
1479
EventImplPtr
1435
1480
Scheduler::GraphBuilder::completeFusion (QueueImplPtr Queue,
1436
1481
std::vector<Command *> &ToEnqueue,
@@ -1451,8 +1496,29 @@ Scheduler::GraphBuilder::completeFusion(QueueImplPtr Queue,
1451
1496
auto *PlaceholderCmd = FusionList->second .get ();
1452
1497
auto &CmdList = PlaceholderCmd->getFusionList ();
1453
1498
1454
- // TODO: The logic to invoke the JIT compiler to create a fused kernel from
1455
- // the list will be added in a later PR.
1499
+ // We need to check if fusing the kernel would create a circular dependency. A
1500
+ // circular dependency would arise, if a kernel in the fusion list
1501
+ // *indirectly* depends on another kernel in the fusion list. Here, indirectly
1502
+ // means, that the dependency is created through a third command not part of
1503
+ // the fusion, on which this kernel depends and which in turn depends on
1504
+ // another kernel in fusion list.
1505
+ bool CreatesCircularDep = false ;
1506
+ for (auto *Cmd : CmdList) {
1507
+ if (checkForCircularDependency (Cmd, true , PlaceholderCmd)) {
1508
+ CreatesCircularDep = true ;
1509
+ break ;
1510
+ }
1511
+ }
1512
+ if (CreatesCircularDep) {
1513
+ // If fusing would create a fused kernel, cancel the fusion.
1514
+ printFusionWarning (
1515
+ " Aborting fusion because it would create a circular dependency" );
1516
+ auto LastEvent = PlaceholderCmd->getEvent ();
1517
+ this ->cancelFusion (Queue, ToEnqueue);
1518
+ return LastEvent;
1519
+ }
1520
+
1521
+ // Call the JIT compiler to generate a new fused kernel.
1456
1522
auto FusedCG = detail::jit_compiler::get_instance ().fuseKernels (
1457
1523
Queue, CmdList, PropList);
1458
1524
0 commit comments