Skip to content

Commit 61ed067

Browse files
authored
SESE: deal with the case where outer loops are moved into the current loop. (#20222)
1 parent 339bc84 commit 61ed067

File tree

3 files changed

+285
-45
lines changed

3 files changed

+285
-45
lines changed

lib/SILOptimizer/Mandatory/TFCanonicalizeCFG.cpp

Lines changed: 126 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -420,11 +420,13 @@ class BasicBlockCloner : public SILClonerWithScopes<BasicBlockCloner> {
420420
"times during SESE cloning.");
421421
}
422422

423-
/// Clone the body of `loop` starting from `startBlock` and nest the cloned
424-
/// fragment into the parent loop. If `startBlock` is the same as the header
425-
/// of `loop`, we clone the entire loop including the back edge. Otherwise,
426-
/// we clone one iteration of the loop body without the back edge.
427-
SILLoop *cloneLoop(SILLoopInfo *LI, SILLoop *loop, SILBasicBlock *startBlock) {
423+
/// Utility to unroll one iteration of the loop or to clone the entire loop.
424+
/// - If `startBlock` is the same as the header of `loop`, we clone the
425+
/// entire loop including the back edge.
426+
/// - Otherwise, we unroll one iteration of the loop body starting from
427+
/// `startBlock' to the latch.
428+
/// The unrolled or cloned version is nested into the parent loop.
429+
SILLoop *cloneOrUnrollLoop(SILLoopInfo *LI, SILLoop *loop, SILBasicBlock *startBlock) {
428430
llvm::DenseMap<SILLoop*, SILLoop*> loopClones;
429431
// This is for convenience as top-level loops have nullptr for parent loop.
430432
loopClones[nullptr] = nullptr;
@@ -514,24 +516,48 @@ class BasicBlockCloner : public SILClonerWithScopes<BasicBlockCloner> {
514516

515517
// A helper class to transform a loop to have a single exit from the header.
516518
class SingleExitLoopTransformer {
517-
public:
518-
SingleExitLoopTransformer(GraphFunctionDeviceInfo *deviceInfo,
519-
SILLoopInfo *LI, DominanceInfo *DI, SILLoop *loop,
520-
PostDominanceInfo *PDI)
521-
: deviceInfo(deviceInfo), DI(DI), PDI(PDI), LI(LI), loop(loop),
522-
header(loop->getHeader()), preheader(loop->getLoopPreheader()),
523-
latch(loop->getLoopLatch()), currentFn(header->getParent()),
524-
oldHeaderNumArgs(header->getNumArguments()), hasUndefsAtPreheader(false) {
525-
assert(preheader && "Canonicalization should have given us one preheader");
526-
assert(latch && "Canonicalization should have given us one latch block");
527-
initialize();
519+
public:
520+
// Convert the given loop into a SESE form. Returns false if the loop was
521+
// already in SESE form. Otherwise, returns true.
522+
static bool doIt(GraphFunctionDeviceInfo *deviceInfo, SILLoopInfo *LI,
523+
DominanceInfo *DI, SILLoop *loop, PostDominanceInfo *PDI) {
524+
SingleExitLoopTransformer transformer(deviceInfo, LI, DI, loop, PDI);
525+
bool loopChanged = transformer.transform();
526+
if (loopChanged) {
527+
// Recalculate dominator information as it is stale now.
528+
DI->recalculate(*transformer.currentFn);
529+
PDI->recalculate(*transformer.currentFn);
530+
}
531+
532+
#ifndef NDEBUG
533+
{
534+
// Verify that the loop is OK after all the transformations.
535+
llvm::DenseSet<const SILLoop *> nestedLoops;
536+
loop->verifyLoopNest(&nestedLoops);
537+
}
538+
#endif
539+
return loopChanged;
540+
}
541+
542+
private:
543+
SingleExitLoopTransformer(GraphFunctionDeviceInfo *deviceInfo,
544+
SILLoopInfo *LI, DominanceInfo *DI, SILLoop *loop,
545+
PostDominanceInfo *PDI)
546+
: deviceInfo(deviceInfo), DI(DI), PDI(PDI), LI(LI), loop(loop),
547+
header(loop->getHeader()), preheader(loop->getLoopPreheader()),
548+
latch(loop->getLoopLatch()), currentFn(header->getParent()),
549+
oldHeaderNumArgs(header->getNumArguments()),
550+
hasUndefsAtPreheader(false) {
551+
assert(preheader &&
552+
"Canonicalization should have given us one preheader");
553+
assert(latch && "Canonicalization should have given us one latch block");
554+
initialize();
528555
}
529556

530557
/// Transforms the loop to ensure it has a single exit from the header.
531558
/// Returns true if the CFG was changed.
532559
bool transform();
533560

534-
private:
535561
// Helper functions
536562

537563
void initialize();
@@ -711,6 +737,24 @@ void SingleExitLoopTransformer::ensureSingleExitBlock() {
711737
<< SILPrintContext(llvm::dbgs()).getID(nearestCommonPD)
712738
<< "\n");
713739

740+
// Compute the set of preheaders of loops that are unrelated to our loop w.r.t
741+
// nesting. This will be used when needing to identify cases where a loop
742+
// from outside is moved into the current loop. e.g.,
743+
// while ... {
744+
// if ... {
745+
// for(...) {...} // This should be nested into the while loop.
746+
// break;
747+
// }
748+
// }
749+
// The unrelated loops are those that are not contained within each other.
750+
SmallPtrSet<SILBasicBlock *, 32> unrelatedPreheaders;
751+
for (auto *otherLoop : *LI) {
752+
if (!otherLoop->contains(loop) && !loop->contains(otherLoop)) {
753+
unrelatedPreheaders.insert(otherLoop->getLoopPreheader());
754+
}
755+
}
756+
757+
714758
// Collect all the blocks from each exiting block up to nearest common PD.
715759
SmallPtrSet<SILBasicBlock *, 32> blocksToBeMoved;
716760
for (SILBasicBlock *exitBlock : exitBlockList) {
@@ -735,6 +779,19 @@ void SingleExitLoopTransformer::ensureSingleExitBlock() {
735779
continue;
736780
}
737781

782+
// Check if `succ` is a preheader of another loop.
783+
SILLoop *succBlockLoop = nullptr;
784+
if (unrelatedPreheaders.count(succ) > 0) {
785+
// We are about a move a loop from outside. Perform canonicalization
786+
// of that loop first.
787+
SILBasicBlock *unrelatedHeader = succ->getSingleSuccessorBlock();
788+
assert(unrelatedHeader &&
789+
"There should be a single successor for a preheader.");
790+
succBlockLoop = LI->getLoopFor(unrelatedHeader);
791+
SingleExitLoopTransformer::doIt(deviceInfo, LI, DI, succBlockLoop,
792+
PDI);
793+
}
794+
738795
if (DI->properlyDominates(header, succ)) {
739796
worklist.insert(succ);
740797
continue;
@@ -752,6 +809,24 @@ void SingleExitLoopTransformer::ensureSingleExitBlock() {
752809

753810
// Clone the block and rewire the edge.
754811
SILBasicBlock *clonedSucc = cloner.initAndCloneBlock(succ);
812+
// If `succ` is a preheader of an unrelated loop, we will have to clone
813+
// the entire loop now so that we can also incrementally update LoopInfo.
814+
if (succBlockLoop) {
815+
SILLoop *clonedLoop = cloner.cloneOrUnrollLoop(
816+
LI, succBlockLoop, succBlockLoop->getHeader());
817+
changeBranchTarget(clonedSucc->getTerminator(), 0,
818+
clonedLoop->getHeader(), /*preserveArgs*/ true);
819+
// Note that all the nodes of `clonedLoop` should be moved into the
820+
// current loop. We do that here itself as an optimization and also
821+
// because the dominator and post-dominator information for the new
822+
// blocks in `clonedLoop` are stale and cannot be relied upon.
823+
for (SILBasicBlock *bb : clonedLoop->getBlocks()) {
824+
blocksToBeMoved.insert(bb);
825+
}
826+
// Add the header to worklist for processing the exit edge.
827+
// (Other successor edges are already processed above.)
828+
worklist.insert(clonedLoop->getHeader());
829+
}
755830
changeBranchTarget(current->getTerminator(), edgeIdx, clonedSucc,
756831
/*preserveArgs*/ true);
757832
worklist.insert(clonedSucc);
@@ -771,24 +846,45 @@ void SingleExitLoopTransformer::ensureSingleExitBlock() {
771846

772847
// Update loop info if this belongs to a parent loop.
773848
SILLoop *outsideBlockLoop = LI->getLoopFor(outsideBlock);
774-
if (outsideBlockLoop != nullptr) {
775-
// FIXME: We don't deal with cases where the nodes being moved in
776-
// belong to another loop yet. e.g.,
849+
if (outsideBlockLoop == nullptr) {
850+
// outsideBlock is not part of any other loop. Simply add it to our loop.
851+
loop->addBasicBlockToLoop(outsideBlock, LI->getBase());
852+
} else {
853+
// We deal with the case where the nodes being moved in
854+
// belong to another loop. e.g.,
777855
// while ... {
778856
// if ... {
779857
// for(...) {...}
780858
// break;
781859
// }
782860
// }
783-
// Check that `loop` is nested within `reachableLoop`.
784-
assert(outsideBlockLoop->contains(loop) &&
785-
"Nodes being moved belong to a non-nested loop.");
786-
// Move the node into our loop.
787-
outsideBlockLoop->removeBlockFromLoop(outsideBlock);
788-
LI->changeLoopFor(outsideBlock, nullptr);
861+
if (outsideBlockLoop->contains(loop)) {
862+
// If our `loop` is nested within `outsideBlockLoop`. Move the node
863+
// from `outsideBlockLoop` into our `loop`.
864+
outsideBlockLoop->removeBlockFromLoop(outsideBlock);
865+
LI->changeLoopFor(outsideBlock, nullptr);
866+
loop->addBasicBlockToLoop(outsideBlock, LI->getBase());
867+
} else {
868+
// We should only nest `outsideBlockLoop` into our `loop` when we
869+
// process the very first node of the `outsideBlockLoop`. Check that we
870+
// have not already nested the `outsideBlockLoop` into our `loop`.
871+
if (!loop->contains(outsideBlockLoop)) {
872+
// Not yet nested, adjust the LoopInfo w.r.t nesting.
873+
if (outsideBlockLoop->getParentLoop() == nullptr) {
874+
// Remove from top-level loops as we are nesting it in `loop`.
875+
LI->removeLoop(llvm::find(*LI, outsideBlockLoop));
876+
}
877+
loop->addChildLoop(outsideBlockLoop);
878+
}
879+
// Add the block to this loop and all its parents.
880+
auto *L = loop;
881+
while (L) {
882+
L->addBlockEntry(outsideBlock);
883+
L = L->getParentLoop();
884+
}
885+
}
789886
// top-level loop is already correct.
790887
}
791-
loop->addBasicBlockToLoop(outsideBlock, LI->getBase());
792888
}
793889
if (cloner.hasCloned()) {
794890
// TODO(https://bugs.swift.org/browse/SR-8336): the transformations here are
@@ -1283,13 +1379,8 @@ void SESERegionBuilder::ensureSingleExitFromLoops() {
12831379
}
12841380
continue;
12851381
}
1286-
SingleExitLoopTransformer transformer(&deviceInfo, &LI, &DI, loop, &PDI);
1287-
bool loopChanged = transformer.transform();
1288-
if (loopChanged) {
1289-
// Recalculate dominator information as it is stale now.
1290-
DI.recalculate(*F);
1291-
PDI.recalculate(*F);
1292-
}
1382+
bool loopChanged =
1383+
SingleExitLoopTransformer::doIt(&deviceInfo, &LI, &DI, loop, &PDI);
12931384
changed |= loopChanged;
12941385
}
12951386
if (changed) {
@@ -1414,7 +1505,7 @@ void SingleExitLoopTransformer::unrollLoopBodyOnce() {
14141505
}
14151506

14161507
// Clone everything starting from the old header.
1417-
cloner.cloneLoop(LI, loop, header);
1508+
cloner.cloneOrUnrollLoop(LI, loop, header);
14181509

14191510
// Get the clone for old header.
14201511
SILBasicBlock *clonedOldHeader = cloner.remapBasicBlock(header);

test/TensorFlow/sese_loop_canonicalization.sil

Lines changed: 107 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %target-sil-opt -tf-xla-cfg-canonicalize -tf-ensure-single-loop-exit -assume-parsing-unqualified-ownership-sil %s -o /dev/null | %FileCheck %s
1+
// RUN: %target-sil-opt -tf-dump-intermediates -tf-xla-cfg-canonicalize -tf-ensure-single-loop-exit -assume-parsing-unqualified-ownership-sil %s -o /dev/null | %FileCheck %s
22

33
import Builtin
44
import Swift
@@ -294,17 +294,17 @@ bb3 (%9 : $Builtin.Int32):
294294
// }
295295
// The CFG is shown below:
296296
//
297-
// [0]-----+
297+
// [0]-------+
298298
// | \
299299
// [1] [2]
300-
// | |
301-
// | +--- [4]
302-
// | | / \
303-
// | | [6] [5]
304-
// | +/ \ |
305-
// [3]----[8] |
306-
// | |
307-
// +-->[9]<------+
300+
// | |
301+
// | +--->[4]
302+
// | [7] / \
303+
// | | [6] [5]
304+
// V +-+ \ |
305+
// [3]<-----[8] |
306+
// | |
307+
// +-->[9]<-----+
308308
//
309309
// This is an example where the nodes that lead to the common post dominator of all exit nodes
310310
// is also reachable from nodes outside of the loop. The key issue is that bb9 is the common
@@ -374,3 +374,100 @@ bb8:
374374
bb9(%18 : $Builtin.Int32):
375375
return %18 : $Builtin.Int32
376376
}
377+
378+
// The following example is similar to loopThatRequiresNodeCloning, where the
379+
// nodes that lead to the common post dominator of all exit nodes is also
380+
// reachable from nodes outside of the loop. The only difference in this example
381+
// is that some of these nodes also belong to a loop. Therefore, it is required
382+
// for us to update the LoopInfo in addition to cloning. The loop in question
383+
// consists of the single node bb3. The CFG is shown below:
384+
//
385+
// [0]-------+
386+
// | \
387+
// [1] [2]
388+
// | |
389+
// | +--->[4]
390+
// | [7] / \
391+
// | | [6] [5]
392+
// V +--+ \ |
393+
// +-->[3]<-----[8] |
394+
// | / | |
395+
// + +-->[9]<-----+
396+
//
397+
// CHECK-LABEL:--- XLA CFG Loops Before Canonicalize: $loopThatRequiresNodeAndLoopCloning
398+
//----The following loop will be cloned in the second loop.
399+
// CHECK:Loop at depth 1 containing: {{bb[0-9]+}}<header><latch><exiting>
400+
// CHECK:Loop at depth 1 containing: {{bb[0-9]+}}<header><exiting>,{{bb[0-9]+}}<exiting>,{{bb[0-9]+}}<latch>
401+
402+
// CHECK-LABEL:--- XLA CFG Loops After Canonicalize: $loopThatRequiresNodeAndLoopCloning
403+
// CHECK:Loop at depth 1 containing: [[L1HDR:bb[0-9]+]]<header><exiting>,[[L1LATCH:bb[0-9]+]]<latch>
404+
// CHECK:Loop at depth 1 containing: [[L2HDR:bb[0-9]+]]<header><exiting>,{{.*}},[[L2LATCH:bb[0-9]+]]<latch>
405+
//----Following is the clone of the first loop.
406+
// CHECK: Loop at depth 2 containing: [[L3HDR:bb[0-9]+]]<header><exiting>,[[L3LATCH:bb[0-9]+]]<latch>
407+
408+
// CHECK-LABEL:--- XLA CFG Canonicalize: $loopThatRequiresNodeAndLoopCloning
409+
// CHECK:[sequence
410+
// CHECK: {condition Header: {{bb[0-9]+}}
411+
// CHECK: [sequence
412+
// CHECK: <while Preheader: {{bb[0-9]+}}, Header: [[L1HDR]], exit: {{bb[0-9]+}}
413+
// CHECK: block [[L1LATCH]]>
414+
// CHECK: block {{bb[0-9]+}}]
415+
// CHECK: [sequence
416+
// CHECK: <while Preheader: {{bb[0-9]+}}, Header: [[L2HDR]], exit: {{bb[0-9]+}}
417+
// CHECK: [sequence
418+
// CHECK: {condition Header: {{bb[0-9]+}}
419+
// CHECK: block {{bb[0-9]+}}
420+
// CHECK: {condition Header: {{bb[0-9]+}}
421+
// CHECK: [sequence
422+
// CHECK: <while Preheader: {{bb[0-9]+}}, Header: [[L3HDR]], exit: {{bb[0-9]+}}
423+
// CHECK: block [[L3LATCH]]>
424+
// CHECK: block {{bb[0-9]+}}]
425+
// CHECK: block {{bb[0-9]+}}}}
426+
// CHECK: block [[L2LATCH]]]>
427+
// CHECK: block {{bb[0-9]+}}]}
428+
// CHECK: block {{bb[0-9]+}}]
429+
430+
sil @$loopThatRequiresNodeAndLoopCloning : $@convention(thin) (Builtin.Int32, Builtin.Int32) -> Builtin.Int32 {
431+
bb0(%0 : $Builtin.Int32, %1 : $Builtin.Int32):
432+
%2 = integer_literal $Builtin.Int32, 1
433+
%3 = integer_literal $Builtin.Int32, 100
434+
%4 = builtin "sadd_with_overflow_Int32"(%0 : $Builtin.Int32, %2 : $Builtin.Int32) : $Builtin.Int32
435+
%5 = builtin "cmp_slt_Int32"(%4 : $Builtin.Int32, %3 : $Builtin.Int32) : $Builtin.Int1
436+
cond_br %5, bb1, bb2
437+
438+
bb1:
439+
// First arg is sum, the second arg is loop counter.
440+
br bb3(%4 : $Builtin.Int32, %0 : $Builtin.Int32)
441+
442+
bb2:
443+
br bb4(%4 : $Builtin.Int32, %0 : $Builtin.Int32)
444+
445+
bb3(%6 : $Builtin.Int32, %7 : $Builtin.Int32):
446+
%8 = builtin "ssub_with_overflow_Int32"(%6 : $Builtin.Int32, %7 : $Builtin.Int32) : $Builtin.Int32
447+
%100 = builtin "cmp_slt_Int32"(%8 : $Builtin.Int32, %7 : $Builtin.Int32) : $Builtin.Int1
448+
cond_br %100, bb3(%7: $Builtin.Int32, %8: $Builtin.Int32), bb9(%8 : $Builtin.Int32)
449+
450+
bb4(%9 : $Builtin.Int32, %10 : $Builtin.Int32):
451+
%11 = builtin "sadd_with_overflow_Int32"(%10 : $Builtin.Int32, %2 : $Builtin.Int32) : $Builtin.Int32
452+
%12 = builtin "sadd_with_overflow_Int32"(%9 : $Builtin.Int32, %10 : $Builtin.Int32) : $Builtin.Int32
453+
%13 = builtin "cmp_slt_Int32"(%11 : $Builtin.Int32, %1 : $Builtin.Int32) : $Builtin.Int1
454+
cond_br %13, bb5, bb6
455+
456+
bb5:
457+
br bb9(%12 : $Builtin.Int32)
458+
459+
bb6:
460+
%14 = builtin "sadd_with_overflow_Int32"(%11 : $Builtin.Int32, %2 : $Builtin.Int32) : $Builtin.Int32
461+
%15 = builtin "sadd_with_overflow_Int32"(%12 : $Builtin.Int32, %14 : $Builtin.Int32) : $Builtin.Int32
462+
%16 = builtin "cmp_slt_Int32"(%14 : $Builtin.Int32, %1 : $Builtin.Int32) : $Builtin.Int1
463+
cond_br %16, bb8, bb7
464+
465+
bb7:
466+
br bb4(%15 : $Builtin.Int32, %14 : $Builtin.Int32)
467+
468+
bb8:
469+
br bb3(%15 : $Builtin.Int32, %14 : $Builtin.Int32)
470+
471+
bb9(%18 : $Builtin.Int32):
472+
return %18 : $Builtin.Int32
473+
}

0 commit comments

Comments
 (0)