Skip to content

Commit c827d51

Browse files
authored
Clone blocks that are reachable from outside the loop. (#19570)
1 parent 7eb1710 commit c827d51

File tree

3 files changed

+221
-46
lines changed

3 files changed

+221
-46
lines changed

lib/SILOptimizer/Mandatory/TFCanonicalizeCFG.cpp

Lines changed: 130 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "swift/SIL/GraphOperationBuilder.h"
3232
#include "swift/SIL/LoopInfo.h"
3333
#include "swift/SIL/SILBuilder.h"
34+
#include "swift/SIL/SILCloner.h"
3435
#include "swift/SIL/SILConstants.h"
3536
#include "swift/SIL/SILUndef.h"
3637
#include "llvm/ADT/iterator_range.h"
@@ -321,6 +322,58 @@ static SILValue createTFIntegerConst(GraphFunctionDeviceInfo &deviceInfo,
321322
return constNode->getResults()[0];
322323
}
323324

325+
namespace {
326+
327+
class BasicBlockCloner : public SILClonerWithScopes<BasicBlockCloner> {
328+
private:
329+
/// The flag to track if this cloner was used to clone any blocks.
330+
bool cloned;
331+
332+
public:
333+
BasicBlockCloner(SILFunction &F)
334+
: SILClonerWithScopes(F), cloned(false) {}
335+
336+
bool hasCloned() const { return cloned; }
337+
338+
/// Return a cloned block.
339+
SILBasicBlock *cloneBlock(SILBasicBlock *bb) {
340+
auto bbIt = BBMap.find(bb);
341+
if (bbIt != BBMap.end())
342+
return bbIt->second;
343+
344+
cloned = true;
345+
346+
SILFunction &F = getBuilder().getFunction();
347+
SILBasicBlock *newBB = F.createBasicBlock();
348+
getBuilder().setInsertionPoint(newBB);
349+
BBMap[bb] = newBB;
350+
// If the basic block has arguments, clone them as well.
351+
for (auto *arg : bb->getArguments()) {
352+
// Create a new argument and copy it into the ValueMap so future
353+
// references use it.
354+
ValueMap[arg] = newBB->createPHIArgument(
355+
arg->getType(), arg->getOwnershipKind(), arg->getDecl());
356+
}
357+
// Clone all the instructions.
358+
for (auto &inst : *bb) {
359+
visit(&inst);
360+
}
361+
return newBB;
362+
}
363+
364+
/// Handle references to basic blocks when cloning.
365+
SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) {
366+
// If the block was not cloned by this cloner, directly reference it.
367+
// Otherwise, use the cloned block.
368+
auto bbIt = BBMap.find(bb);
369+
if (bbIt != BBMap.end())
370+
return bbIt->second;
371+
return bb;
372+
}
373+
};
374+
375+
} // namespace
376+
324377
// A helper class to transform a loop to have a single exit from the header.
325378
class SingleExitLoopTransformer {
326379
public:
@@ -432,6 +485,8 @@ class SingleExitLoopTransformer {
432485
llvm::DenseMap<SILValue, SILValue> escapingValueSubstMap;
433486
/// Similar to escapingValueSubstMap, but arguments of exit blocks.
434487
llvm::DenseMap<SILValue, SILValue> exitArgSubstMap;
488+
// Map from an arg of an exit block to the corresponding newly added header arg.
489+
llvm::DenseMap<SILValue, SILValue> exitArgHeaderArgMap;
435490
};
436491

437492
void SingleExitLoopTransformer::initialize() {
@@ -492,6 +547,8 @@ void SingleExitLoopTransformer::initialize() {
492547
}
493548

494549
void SingleExitLoopTransformer::ensureSingleExitBlock() {
550+
BasicBlockCloner cloner(*currentFn);
551+
495552
// Identify the common post dominator
496553
SILPrintContext printContext(llvm::dbgs());
497554
SmallVector<SILBasicBlock*, 8> exitBlockList;
@@ -528,21 +585,31 @@ void SingleExitLoopTransformer::ensureSingleExitBlock() {
528585
// appropriately and not deal with touching stale memory.
529586
auto succs = current->getSuccessors();
530587
auto *succ = succs[edgeIdx].getBB();
531-
// Skip if (1) already processed, (2) reached common pd, or (3)
532-
// block has in edges from outside the loop. In the last case, we will
533-
// need to clone the blocks.
588+
// Skip if (1) already processed or (2) reached common pd.
534589
if (blocksToBeMoved.count(succ) > 0 || succ == nearestCommonPD) {
535590
continue;
536591
}
537-
if (!DI->properlyDominates(header, succ)) {
538-
// Split this edge so that we don't mess up arguments passed in
539-
// from other predecessors of succ.
540-
if (succ->getNumArguments() > 0) {
541-
splitEdge(current->getTerminator(), edgeIdx, DI, LI);
542-
}
592+
593+
if (DI->properlyDominates(header, succ)) {
594+
worklist.insert(succ);
543595
continue;
544596
}
545-
worklist.insert(succ);
597+
// If `succ` is not dominated by `header`, then `succ` is reachable from
598+
// a node outside of this loop. We might have to clone `succ` in such
599+
// cases.
600+
601+
// Before cloning make sure that header -> succ is *not* backedge of a
602+
// parent loop. This can happen when we have labeled breaks in loops. We
603+
// cannot clone the blocks in such cases. Simply continue. This is still
604+
// OK for our purposes because we will find an equivalent value at the
605+
// header for any value that escapes along this edge.
606+
if (DI->properlyDominates(succ, header)) continue;
607+
608+
// Clone the block and rewire the edge.
609+
SILBasicBlock *clonedSucc = cloner.cloneBlock(succ);
610+
changeBranchTarget(current->getTerminator(), edgeIdx, clonedSucc,
611+
/*preserveArgs*/ true);
612+
worklist.insert(clonedSucc);
546613
}
547614
}
548615
}
@@ -578,6 +645,12 @@ void SingleExitLoopTransformer::ensureSingleExitBlock() {
578645
}
579646
loop->addBasicBlockToLoop(outsideBlock, LI->getBase());
580647
}
648+
if (cloner.hasCloned()) {
649+
// TODO(https://bugs.swift.org/browse/SR-8336): the transformations here are
650+
// simple that we should be able to incrementally update the DI & PDI.
651+
DI->recalculate(*currentFn);
652+
PDI->recalculate(*currentFn);
653+
}
581654
}
582655

583656
llvm::DenseMap<SILValue, SILValue>
@@ -699,30 +772,30 @@ SingleExitLoopTransformer::createNewHeader() {
699772
}
700773
header->dropAllArguments();
701774
// Add phi arguments in the new header corresponding to the escaping values.
702-
auto addArgument =
703-
[this, newHeader](SILValue escapingValue) {
704-
SILValue newValue = newHeader->createPHIArgument(
705-
escapingValue->getType(), escapingValue.getOwnershipKind());
706-
// Replace uses *outside* of the loop with the new value.
707-
auto UI = escapingValue->use_begin(), E = escapingValue->use_end();
708-
while (UI != E) {
709-
Operand *use = *UI;
710-
// Increment iterator before we invalidate it
711-
// when we invoke Operand::Set below.
712-
++UI;
713-
if (loop->contains(use->getUser()->getParent())) {
714-
continue;
715-
}
716-
use->set(newValue);
717-
}
718-
};
719-
720775
for (const auto &kv : escapingValueSubstMap) {
721-
addArgument(kv.first);
776+
SILValue escapingValue = kv.first;
777+
SILValue newValue = newHeader->createPHIArgument(
778+
escapingValue->getType(), escapingValue.getOwnershipKind());
779+
// Replace uses *outside* of the loop with the new value.
780+
auto UI = escapingValue->use_begin(), E = escapingValue->use_end();
781+
while (UI != E) {
782+
Operand *use = *UI;
783+
// Increment iterator before we invalidate it
784+
// when we invoke Operand::Set below.
785+
++UI;
786+
if (loop->contains(use->getUser()->getParent())) {
787+
continue;
788+
}
789+
use->set(newValue);
790+
}
722791
}
723792
if (TFNoUndefsInSESE) {
793+
// Add arguments in the new header corresponding to exit block arguments.
724794
for (const auto &kv : exitArgSubstMap) {
725-
addArgument(kv.first);
795+
SILValue arg = kv.first;
796+
SILValue newValue =
797+
newHeader->createPHIArgument(arg->getType(), arg.getOwnershipKind());
798+
exitArgHeaderArgMap[kv.first] = newValue;
726799
}
727800
}
728801
// An integer to identify the exit edge.
@@ -901,15 +974,6 @@ SingleExitLoopTransformer::patchEdges(SILBasicBlock *newHeader,
901974
SILBasicBlock *SingleExitLoopTransformer::createNewExitBlockWithDemux(
902975
const llvm::DenseMap<SILBasicBlock *, intmax_t> &exitIndices,
903976
SILValue exitIndexArg) {
904-
if (TFNoUndefsInSESE) {
905-
// Drop all arguments as we have moved them into the headers arguments.
906-
for (SILBasicBlock *exitBlock : exitBlocks) {
907-
exitBlock->dropAllArguments();
908-
}
909-
}
910-
if (exitBlocks.size() == 1) {
911-
return *exitBlocks.begin();
912-
}
913977
auto createBlockOutsideLoop = [this]() {
914978
SILBasicBlock *newBlock = currentFn->createBasicBlock();
915979
SILLoop *parentLoop = loop->getParentLoop();
@@ -919,11 +983,9 @@ SILBasicBlock *SingleExitLoopTransformer::createNewExitBlockWithDemux(
919983
return newBlock;
920984
};
921985

922-
// Create a new exit block.
923-
// FIXME: We can avoid creating an additional block and instead connect the
924-
// header directly to the demuxBlock created in the loop below. Alternatively,
925-
// we can also use contractUncondBranches in TFParititon.cpp to remove this
926-
// block later.
986+
// Create a new exit block. Strictly, we don't always need this block, but it
987+
// makes it slightly easier to implement the demux blocks. contractUncondEdges
988+
// will merge this block away if appropriate.
927989
SILBasicBlock *newExitBlock = createBlockOutsideLoop();
928990

929991
SILBuilder builder(newExitBlock);
@@ -934,6 +996,17 @@ SILBasicBlock *SingleExitLoopTransformer::createNewExitBlockWithDemux(
934996
SILLocation headerLocation =
935997
getUserSourceLocation(header->getTerminator()->getDebugLocation());
936998

999+
// Find the arguments at the header that were added for the exit arguments
1000+
// and pass that along to the original exit block.
1001+
auto remapExitArguments = [this](SILBasicBlock *exitingBlock,
1002+
SILBasicBlock *exitBlock) {
1003+
SmallVector<SILValue, 8> headerArgs;
1004+
for (SILValue arg : exitBlock->getArguments()) {
1005+
headerArgs.push_back(exitArgHeaderArgMap[arg]);
1006+
}
1007+
appendArguments(exitingBlock->getTerminator(), exitBlock, headerArgs);
1008+
};
1009+
9371010
while (curBlockIter != exitBlocks.end()) {
9381011
SILBasicBlock *newBlock = createBlockOutsideLoop();
9391012
SILBasicBlock *trueBlock = *curBlockIter++;
@@ -959,10 +1032,18 @@ SILBasicBlock *SingleExitLoopTransformer::createNewExitBlockWithDemux(
9591032
condTensorInst->getResults()[0], builder, headerLocation, *deviceInfo);
9601033
builder.createCondBranch(headerLocation, condValue->getResults()[0],
9611034
trueBlock, demuxBlock);
1035+
1036+
if (TFNoUndefsInSESE) {
1037+
remapExitArguments(newBlock, trueBlock);
1038+
remapExitArguments(newBlock, demuxBlock);
1039+
}
9621040
demuxBlock = newBlock;
9631041
}
9641042
builder.setInsertionPoint(newExitBlock);
9651043
builder.createBranch(headerLocation, demuxBlock);
1044+
if (TFNoUndefsInSESE) {
1045+
remapExitArguments(newExitBlock, demuxBlock);
1046+
}
9661047
return newExitBlock;
9671048
}
9681049

@@ -1061,7 +1142,11 @@ void SESERegionBuilder::ensureSingleExitFromLoops() {
10611142
changed |= loopChanged;
10621143
}
10631144
if (changed) {
1064-
splitAllCondBrCriticalEdgesWithNonTrivialArgs(*F, nullptr, &LI);
1145+
splitAllCondBrCriticalEdgesWithNonTrivialArgs(*F, &DI, &LI);
1146+
contractUncondBranches(F, &DI, &LI);
1147+
// TODO(https://bugs.swift.org/browse/SR-8336): the transformations here are
1148+
// simple that we should be able to incrementally update PDI.
1149+
PDI.recalculate(*F);
10651150
}
10661151
}
10671152

lib/SILOptimizer/Utils/CFG.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ bool swift::mergeBasicBlockWithSuccessor(SILBasicBlock *BB, DominanceInfo *DT,
822822
auto *BBNode = DT->getNode(BB);
823823
SmallVector<DominanceInfoNode *, 8> Children(SuccBBNode->begin(),
824824
SuccBBNode->end());
825-
for (auto *ChildNode : *SuccBBNode)
825+
for (auto *ChildNode : Children)
826826
DT->changeImmediateDominator(ChildNode, BBNode);
827827

828828
DT->eraseNode(SuccBB);

test/TensorFlow/sese_loop_canonicalization.sil

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,93 @@ bb2:
266266
bb3 (%9 : $Builtin.Int32):
267267
return %9 : $Builtin.Int32
268268
}
269+
270+
// The SIL code is a vastly simplified version of the SIL obtained from swift
271+
// for the following function in -O mode:
272+
// public func natSumWithBreak(_ breakIndex: Int32) -> Tensor<Int32>{
273+
// var i: Int32 = 1
274+
// var sum = Tensor<Int32>(0)
275+
// let maxCount: Int32 = 100
276+
// while i <= maxCount {
277+
// sum += i
278+
// i += 1
279+
// var j = Tensor<Int32>(breakIndex * i)
280+
// if (i == breakIndex) {
281+
// sum -= j
282+
// break
283+
// }
284+
// j -= j
285+
// sum += j
286+
// }
287+
// sum += sum
288+
// return sum
289+
// }
290+
//
291+
// This is an example where the nodes that lead to the common post dominator of all exit nodes
292+
// is also reachable from nodes outside of the loop. The key issue is that bb9 is the common
293+
// post dominator of the exit blocks bb5 and bb8 of the loop. As a first step, we attempt to
294+
// make the common post dominator the single exit block. In order to do that, we need to
295+
// move the blocks reachable from exit blocks bb5 and bb8 up to, but excluding bb9, into the loop.
296+
// In this example, bb3 needs to be moved into the loop, but is reachable from from bb1 (outside
297+
// the loop) and bb8 (which will be moved inside the loop). To deal with this case, we simply clone
298+
// bb3 before moving.
299+
//
300+
// CHECK-LABEL: --- XLA CFG Canonicalize: $loopThatRequiresNodeCloning
301+
// CHECK: [sequence
302+
// CHECK: {condition Header: bb0
303+
// CHECK: block bb1
304+
// CHECK: [sequence
305+
// CHECK: <while Preheader: bb2, Header: bb9, exit: bb11
306+
// CHECK: [sequence
307+
// CHECK: {condition Header: bb3
308+
// CHECK: block bb4
309+
// CHECK: {condition Header: bb5
310+
// CHECK: block bb7
311+
// CHECK: block bb6}}
312+
// CHECK: block bb10]>
313+
// CHECK: block bb11]}
314+
// CHECK: block bb8]
315+
// CHECK: --- XLA CFG Canonicalize end
316+
sil @$loopThatRequiresNodeCloning : $@convention(thin) (Builtin.Int32, Builtin.Int32) -> Builtin.Int32 {
317+
bb0(%0 : $Builtin.Int32, %1 : $Builtin.Int32):
318+
%2 = integer_literal $Builtin.Int32, 1
319+
%3 = integer_literal $Builtin.Int32, 100
320+
%4 = builtin "sadd_with_overflow_Int32"(%0 : $Builtin.Int32, %2 : $Builtin.Int32) : $Builtin.Int32
321+
%5 = builtin "cmp_slt_Int32"(%4 : $Builtin.Int32, %3 : $Builtin.Int32) : $Builtin.Int1
322+
cond_br %5, bb1, bb2
323+
324+
bb1:
325+
// First arg is sum, the second arg is loop counter.
326+
br bb3(%4 : $Builtin.Int32, %0 : $Builtin.Int32)
327+
328+
bb2:
329+
br bb4(%4 : $Builtin.Int32, %0 : $Builtin.Int32)
330+
331+
bb3(%6 : $Builtin.Int32, %7 : $Builtin.Int32):
332+
%8 = builtin "ssub_with_overflow_Int32"(%6 : $Builtin.Int32, %7 : $Builtin.Int32) : $Builtin.Int32
333+
br bb9(%8 : $Builtin.Int32)
334+
335+
bb4(%9 : $Builtin.Int32, %10 : $Builtin.Int32):
336+
%11 = builtin "sadd_with_overflow_Int32"(%10 : $Builtin.Int32, %2 : $Builtin.Int32) : $Builtin.Int32
337+
%12 = builtin "sadd_with_overflow_Int32"(%9 : $Builtin.Int32, %10 : $Builtin.Int32) : $Builtin.Int32
338+
%13 = builtin "cmp_slt_Int32"(%11 : $Builtin.Int32, %1 : $Builtin.Int32) : $Builtin.Int1
339+
cond_br %13, bb5, bb6
340+
341+
bb5:
342+
br bb9(%12 : $Builtin.Int32)
343+
344+
bb6:
345+
%14 = builtin "sadd_with_overflow_Int32"(%11 : $Builtin.Int32, %2 : $Builtin.Int32) : $Builtin.Int32
346+
%15 = builtin "sadd_with_overflow_Int32"(%12 : $Builtin.Int32, %14 : $Builtin.Int32) : $Builtin.Int32
347+
%16 = builtin "cmp_slt_Int32"(%14 : $Builtin.Int32, %1 : $Builtin.Int32) : $Builtin.Int1
348+
cond_br %16, bb8, bb7
349+
350+
bb7:
351+
br bb4(%15 : $Builtin.Int32, %14 : $Builtin.Int32)
352+
353+
bb8:
354+
br bb3(%15 : $Builtin.Int32, %14 : $Builtin.Int32)
355+
356+
bb9(%18 : $Builtin.Int32):
357+
return %18 : $Builtin.Int32
358+
}

0 commit comments

Comments
 (0)