Skip to content

Clone blocks that are reachable from outside the loop. #19570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 130 additions & 45 deletions lib/SILOptimizer/Mandatory/TFCanonicalizeCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "swift/SIL/GraphOperationBuilder.h"
#include "swift/SIL/LoopInfo.h"
#include "swift/SIL/SILBuilder.h"
#include "swift/SIL/SILCloner.h"
#include "swift/SIL/SILConstants.h"
#include "swift/SIL/SILUndef.h"
#include "llvm/ADT/iterator_range.h"
Expand Down Expand Up @@ -321,6 +322,58 @@ static SILValue createTFIntegerConst(GraphFunctionDeviceInfo &deviceInfo,
return constNode->getResults()[0];
}

namespace {

/// Clones individual basic blocks into the function this cloner was created
/// for, remembering in BBMap which original block maps to which copy.
class BasicBlockCloner : public SILClonerWithScopes<BasicBlockCloner> {
  /// Becomes true the first time cloneBlock() actually creates a new block.
  bool cloned = false;

public:
  BasicBlockCloner(SILFunction &F) : SILClonerWithScopes(F) {}

  /// Whether at least one block has been cloned through this object.
  bool hasCloned() const { return cloned; }

  /// Return the copy of \p bb, creating it on the first request.
  ///
  /// A block that already has an entry in BBMap is not cloned again; the
  /// recorded copy is returned instead.
  SILBasicBlock *cloneBlock(SILBasicBlock *bb) {
    auto known = BBMap.find(bb);
    if (known != BBMap.end())
      return known->second;

    cloned = true;

    // Create the destination block and direct the builder at it before any
    // instruction is visited.
    SILBasicBlock *copy = getBuilder().getFunction().createBasicBlock();
    getBuilder().setInsertionPoint(copy);
    BBMap[bb] = copy;

    // Mirror every block argument on the copy and register the pairing in
    // ValueMap so that later references resolve to the new argument.
    for (auto *origArg : bb->getArguments()) {
      ValueMap[origArg] = copy->createPHIArgument(origArg->getType(),
                                                  origArg->getOwnershipKind(),
                                                  origArg->getDecl());
    }

    // Visit the instructions of the original block one by one, in order.
    for (auto &origInst : *bb) {
      visit(&origInst);
    }
    return copy;
  }

  /// Resolve a basic block reference encountered while cloning: blocks that
  /// this cloner has copied map to their copy, all other blocks are referenced
  /// directly.
  SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) {
    auto known = BBMap.find(bb);
    return known == BBMap.end() ? bb : known->second;
  }
};

} // namespace

// A helper class to transform a loop to have a single exit from the header.
class SingleExitLoopTransformer {
public:
Expand Down Expand Up @@ -432,6 +485,8 @@ class SingleExitLoopTransformer {
llvm::DenseMap<SILValue, SILValue> escapingValueSubstMap;
/// Similar to escapingValueSubstMap, but arguments of exit blocks.
llvm::DenseMap<SILValue, SILValue> exitArgSubstMap;
// Map from an arg of an exit block to the corresponding newly added header arg.
llvm::DenseMap<SILValue, SILValue> exitArgHeaderArgMap;
};

void SingleExitLoopTransformer::initialize() {
Expand Down Expand Up @@ -492,6 +547,8 @@ void SingleExitLoopTransformer::initialize() {
}

void SingleExitLoopTransformer::ensureSingleExitBlock() {
BasicBlockCloner cloner(*currentFn);

// Identify the common post dominator
SILPrintContext printContext(llvm::dbgs());
SmallVector<SILBasicBlock*, 8> exitBlockList;
Expand Down Expand Up @@ -528,21 +585,31 @@ void SingleExitLoopTransformer::ensureSingleExitBlock() {
// appropriately and not deal with touching stale memory.
auto succs = current->getSuccessors();
auto *succ = succs[edgeIdx].getBB();
// Skip if (1) already processed, (2) reached common pd, or (3)
// block has in edges from outside the loop. In the last case, we will
// need to clone the blocks.
// Skip if (1) already processed or (2) reached common pd.
if (blocksToBeMoved.count(succ) > 0 || succ == nearestCommonPD) {
continue;
}
if (!DI->properlyDominates(header, succ)) {
// Split this edge so that we don't mess up arguments passed in
// from other predecessors of succ.
if (succ->getNumArguments() > 0) {
splitEdge(current->getTerminator(), edgeIdx, DI, LI);
}

if (DI->properlyDominates(header, succ)) {
worklist.insert(succ);
continue;
}
worklist.insert(succ);
// If `succ` is not dominated by `header`, then `succ` is reachable from
// a node outside of this loop. We might have to clone `succ` in such
// cases.

// Before cloning make sure that header -> succ is *not* a backedge of a
// parent loop. This can happen when we have labeled breaks in loops. We
// cannot clone the blocks in such cases. Simply continue. This is still
// OK for our purposes because we will find an equivalent value at the
// header for any value that escapes along this edge.
if (DI->properlyDominates(succ, header)) continue;

// Clone the block and rewire the edge.
SILBasicBlock *clonedSucc = cloner.cloneBlock(succ);
changeBranchTarget(current->getTerminator(), edgeIdx, clonedSucc,
/*preserveArgs*/ true);
worklist.insert(clonedSucc);
}
}
}
Expand Down Expand Up @@ -578,6 +645,12 @@ void SingleExitLoopTransformer::ensureSingleExitBlock() {
}
loop->addBasicBlockToLoop(outsideBlock, LI->getBase());
}
if (cloner.hasCloned()) {
// TODO(https://bugs.swift.org/browse/SR-8336): the transformations here are
// simple enough that we should be able to incrementally update the DI & PDI.
DI->recalculate(*currentFn);
PDI->recalculate(*currentFn);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, I think this approach is totally fine, assuming that this isn't commonly necessary. If it is commonly necessary, it would be interesting to see what the patterns are and maybe a trivial local update would work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! This should not be commonly necessary as it only happens when we clone blocks. I will leave it for now.

}
}

llvm::DenseMap<SILValue, SILValue>
Expand Down Expand Up @@ -699,30 +772,30 @@ SingleExitLoopTransformer::createNewHeader() {
}
header->dropAllArguments();
// Add phi arguments in the new header corresponding to the escaping values.
auto addArgument =
[this, newHeader](SILValue escapingValue) {
SILValue newValue = newHeader->createPHIArgument(
escapingValue->getType(), escapingValue.getOwnershipKind());
// Replace uses *outside* of the loop with the new value.
auto UI = escapingValue->use_begin(), E = escapingValue->use_end();
while (UI != E) {
Operand *use = *UI;
// Increment iterator before we invalidate it
// when we invoke Operand::Set below.
++UI;
if (loop->contains(use->getUser()->getParent())) {
continue;
}
use->set(newValue);
}
};

for (const auto &kv : escapingValueSubstMap) {
addArgument(kv.first);
SILValue escapingValue = kv.first;
SILValue newValue = newHeader->createPHIArgument(
escapingValue->getType(), escapingValue.getOwnershipKind());
// Replace uses *outside* of the loop with the new value.
auto UI = escapingValue->use_begin(), E = escapingValue->use_end();
while (UI != E) {
Operand *use = *UI;
// Increment iterator before we invalidate it
// when we invoke Operand::Set below.
++UI;
if (loop->contains(use->getUser()->getParent())) {
continue;
}
use->set(newValue);
}
}
if (TFNoUndefsInSESE) {
// Add arguments in the new header corresponding to exit block arguments.
for (const auto &kv : exitArgSubstMap) {
addArgument(kv.first);
SILValue arg = kv.first;
SILValue newValue =
newHeader->createPHIArgument(arg->getType(), arg.getOwnershipKind());
exitArgHeaderArgMap[kv.first] = newValue;
}
}
// An integer to identify the exit edge.
Expand Down Expand Up @@ -901,15 +974,6 @@ SingleExitLoopTransformer::patchEdges(SILBasicBlock *newHeader,
SILBasicBlock *SingleExitLoopTransformer::createNewExitBlockWithDemux(
const llvm::DenseMap<SILBasicBlock *, intmax_t> &exitIndices,
SILValue exitIndexArg) {
if (TFNoUndefsInSESE) {
// Drop all arguments as we have moved them into the headers arguments.
for (SILBasicBlock *exitBlock : exitBlocks) {
exitBlock->dropAllArguments();
}
}
if (exitBlocks.size() == 1) {
return *exitBlocks.begin();
}
auto createBlockOutsideLoop = [this]() {
SILBasicBlock *newBlock = currentFn->createBasicBlock();
SILLoop *parentLoop = loop->getParentLoop();
Expand All @@ -919,11 +983,9 @@ SILBasicBlock *SingleExitLoopTransformer::createNewExitBlockWithDemux(
return newBlock;
};

// Create a new exit block.
// FIXME: We can avoid creating an additional block and instead connect the
// header directly to the demuxBlock created in the loop below. Alternatively,
// we can also use contractUncondBranches in TFParititon.cpp to remove this
// block later.
// Create a new exit block. Strictly, we don't always need this block, but it
// makes it slightly easier to implement the demux blocks. contractUncondEdges
// will merge this block away if appropriate.
SILBasicBlock *newExitBlock = createBlockOutsideLoop();

SILBuilder builder(newExitBlock);
Expand All @@ -934,6 +996,17 @@ SILBasicBlock *SingleExitLoopTransformer::createNewExitBlockWithDemux(
SILLocation headerLocation =
getUserSourceLocation(header->getTerminator()->getDebugLocation());

// Find the arguments at the header that were added for the exit arguments
// and pass that along to the original exit block.
auto remapExitArguments = [this](SILBasicBlock *exitingBlock,
SILBasicBlock *exitBlock) {
SmallVector<SILValue, 8> headerArgs;
for (SILValue arg : exitBlock->getArguments()) {
headerArgs.push_back(exitArgHeaderArgMap[arg]);
}
appendArguments(exitingBlock->getTerminator(), exitBlock, headerArgs);
};

while (curBlockIter != exitBlocks.end()) {
SILBasicBlock *newBlock = createBlockOutsideLoop();
SILBasicBlock *trueBlock = *curBlockIter++;
Expand All @@ -959,10 +1032,18 @@ SILBasicBlock *SingleExitLoopTransformer::createNewExitBlockWithDemux(
condTensorInst->getResults()[0], builder, headerLocation, *deviceInfo);
builder.createCondBranch(headerLocation, condValue->getResults()[0],
trueBlock, demuxBlock);

if (TFNoUndefsInSESE) {
remapExitArguments(newBlock, trueBlock);
remapExitArguments(newBlock, demuxBlock);
}
demuxBlock = newBlock;
}
builder.setInsertionPoint(newExitBlock);
builder.createBranch(headerLocation, demuxBlock);
if (TFNoUndefsInSESE) {
remapExitArguments(newExitBlock, demuxBlock);
}
return newExitBlock;
}

Expand Down Expand Up @@ -1061,7 +1142,11 @@ void SESERegionBuilder::ensureSingleExitFromLoops() {
changed |= loopChanged;
}
if (changed) {
splitAllCondBrCriticalEdgesWithNonTrivialArgs(*F, nullptr, &LI);
splitAllCondBrCriticalEdgesWithNonTrivialArgs(*F, &DI, &LI);
contractUncondBranches(F, &DI, &LI);
// TODO(https://bugs.swift.org/browse/SR-8336): the transformations here are
// simple enough that we should be able to incrementally update PDI.
PDI.recalculate(*F);
}
}

Expand Down
2 changes: 1 addition & 1 deletion lib/SILOptimizer/Utils/CFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ bool swift::mergeBasicBlockWithSuccessor(SILBasicBlock *BB, DominanceInfo *DT,
auto *BBNode = DT->getNode(BB);
SmallVector<DominanceInfoNode *, 8> Children(SuccBBNode->begin(),
SuccBBNode->end());
for (auto *ChildNode : *SuccBBNode)
for (auto *ChildNode : Children)
DT->changeImmediateDominator(ChildNode, BBNode);

DT->eraseNode(SuccBB);
Expand Down
90 changes: 90 additions & 0 deletions test/TensorFlow/sese_loop_canonicalization.sil
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,93 @@ bb2:
bb3 (%9 : $Builtin.Int32):
return %9 : $Builtin.Int32
}

// The SIL code is a vastly simplified version of the SIL obtained from swift
// for the following function in -O mode:
// public func natSumWithBreak(_ breakIndex: Int32) -> Tensor<Int32>{
// var i: Int32 = 1
// var sum = Tensor<Int32>(0)
// let maxCount: Int32 = 100
// while i <= maxCount {
// sum += i
// i += 1
// var j = Tensor<Int32>(breakIndex * i)
// if (i == breakIndex) {
// sum -= j
// break
// }
// j -= j
// sum += j
// }
// sum += sum
// return sum
// }
//
// This is an example where the nodes that lead to the common post dominator of all exit nodes
// are also reachable from nodes outside of the loop. The key issue is that bb9 is the common
// post dominator of the exit blocks bb5 and bb8 of the loop. As a first step, we attempt to
// make the common post dominator the single exit block. In order to do that, we need to
// move the blocks reachable from exit blocks bb5 and bb8 up to, but excluding bb9, into the loop.
// In this example, bb3 needs to be moved into the loop, but is reachable from bb1 (outside
// the loop) and bb8 (which will be moved inside the loop). To deal with this case, we simply clone
// bb3 before moving.
//
// CHECK-LABEL: --- XLA CFG Canonicalize: $loopThatRequiresNodeCloning
// CHECK: [sequence
// CHECK: {condition Header: bb0
// CHECK: block bb1
// CHECK: [sequence
// CHECK: <while Preheader: bb2, Header: bb9, exit: bb11
// CHECK: [sequence
// CHECK: {condition Header: bb3
// CHECK: block bb4
// CHECK: {condition Header: bb5
// CHECK: block bb7
// CHECK: block bb6}}
// CHECK: block bb10]>
// CHECK: block bb11]}
// CHECK: block bb8]
// CHECK: --- XLA CFG Canonicalize end
// Test input for loop canonicalization. The loop consists of bb4, bb6 and bb7
// (bb7 branches back to bb4). Control leaves it either through bb5 or through
// bb8, and both paths end at bb9. bb3 sits on the bb8 path but is also
// branched to from bb1, which is not part of the loop.
sil @$loopThatRequiresNodeCloning : $@convention(thin) (Builtin.Int32, Builtin.Int32) -> Builtin.Int32 {
bb0(%0 : $Builtin.Int32, %1 : $Builtin.Int32):
%2 = integer_literal $Builtin.Int32, 1
%3 = integer_literal $Builtin.Int32, 100
%4 = builtin "sadd_with_overflow_Int32"(%0 : $Builtin.Int32, %2 : $Builtin.Int32) : $Builtin.Int32
%5 = builtin "cmp_slt_Int32"(%4 : $Builtin.Int32, %3 : $Builtin.Int32) : $Builtin.Int1
// bb1 bypasses the loop and goes straight to bb3; bb2 leads into the loop.
cond_br %5, bb1, bb2

bb1:
// First arg is sum, the second arg is loop counter.
br bb3(%4 : $Builtin.Int32, %0 : $Builtin.Int32)

bb2:
// The only branch into bb4 from outside the loop.
br bb4(%4 : $Builtin.Int32, %0 : $Builtin.Int32)

// Branched to from bb1 (outside the loop) and from bb8.
bb3(%6 : $Builtin.Int32, %7 : $Builtin.Int32):
%8 = builtin "ssub_with_overflow_Int32"(%6 : $Builtin.Int32, %7 : $Builtin.Int32) : $Builtin.Int32
br bb9(%8 : $Builtin.Int32)

// Loop header: entered from bb2 and re-entered from bb7.
bb4(%9 : $Builtin.Int32, %10 : $Builtin.Int32):
%11 = builtin "sadd_with_overflow_Int32"(%10 : $Builtin.Int32, %2 : $Builtin.Int32) : $Builtin.Int32
%12 = builtin "sadd_with_overflow_Int32"(%9 : $Builtin.Int32, %10 : $Builtin.Int32) : $Builtin.Int32
%13 = builtin "cmp_slt_Int32"(%11 : $Builtin.Int32, %1 : $Builtin.Int32) : $Builtin.Int1
cond_br %13, bb5, bb6

// First way out of the loop: forwards %12 directly to bb9.
bb5:
br bb9(%12 : $Builtin.Int32)

bb6:
%14 = builtin "sadd_with_overflow_Int32"(%11 : $Builtin.Int32, %2 : $Builtin.Int32) : $Builtin.Int32
%15 = builtin "sadd_with_overflow_Int32"(%12 : $Builtin.Int32, %14 : $Builtin.Int32) : $Builtin.Int32
%16 = builtin "cmp_slt_Int32"(%14 : $Builtin.Int32, %1 : $Builtin.Int32) : $Builtin.Int1
cond_br %16, bb8, bb7

// Back edge to the loop header.
bb7:
br bb4(%15 : $Builtin.Int32, %14 : $Builtin.Int32)

// Second way out of the loop: reaches bb9 through bb3.
bb8:
br bb3(%15 : $Builtin.Int32, %14 : $Builtin.Int32)

// Join point of bb3 and bb5; returns the incoming value.
bb9(%18 : $Builtin.Int32):
return %18 : $Builtin.Int32
}