Skip to content

[Coroutines] Move the SuspendCrossingInfo analysis helper into its own header/source #106306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Coroutines/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_llvm_component_library(LLVMCoroutines
CoroElide.cpp
CoroFrame.cpp
CoroSplit.cpp
SuspendCrossingInfo.cpp

ADDITIONAL_HEADER_DIRS
${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms/Coroutines
Expand Down
322 changes: 12 additions & 310 deletions llvm/lib/Transforms/Coroutines/CoroFrame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
//===----------------------------------------------------------------------===//

#include "CoroInternal.h"
#include "SuspendCrossingInfo.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
Expand Down Expand Up @@ -51,315 +52,6 @@ extern cl::opt<bool> UseNewDbgInfoFormat;
// "coro-frame", which results in leaner debug spew.
#define DEBUG_TYPE "coro-suspend-crossing"

// Inline element count passed to the SmallVector instantiations used by the
// suspend-crossing analysis (the sorted block list and the per-block data).
enum { SmallVectorThreshold = 32 };

// Provides two way mapping between the blocks and numbers.
namespace {
class BlockToIndexMapping {
SmallVector<BasicBlock *, SmallVectorThreshold> V;

public:
size_t size() const { return V.size(); }

BlockToIndexMapping(Function &F) {
for (BasicBlock &BB : F)
V.push_back(&BB);
llvm::sort(V);
}

size_t blockToIndex(BasicBlock const *BB) const {
auto *I = llvm::lower_bound(V, BB);
assert(I != V.end() && *I == BB && "BasicBlockNumberng: Unknown block");
return I - V.begin();
}

BasicBlock *indexToBlock(unsigned Index) const { return V[Index]; }
};
} // end anonymous namespace

// The SuspendCrossingInfo maintains data that allows to answer a question
// whether given two BasicBlocks A and B there is a path from A to B that
// passes through a suspend point.
//
// For every basic block 'i' it maintains a BlockData that consists of:
// Consumes: a bit vector which contains a set of indices of blocks that can
// reach block 'i'. A block can trivially reach itself.
// Kills: a bit vector which contains a set of indices of blocks that can
// reach block 'i' but there is a path crossing a suspend point
// not repeating 'i' (path to 'i' without cycles containing 'i').
// Suspend: a boolean indicating whether block 'i' contains a suspend point.
// End: a boolean indicating whether block 'i' contains a coro.end intrinsic.
// KillLoop: There is a path from 'i' to 'i' not otherwise repeating 'i' that
// crosses a suspend point.
//
namespace {
class SuspendCrossingInfo {
BlockToIndexMapping Mapping;

struct BlockData {
BitVector Consumes;
BitVector Kills;
bool Suspend = false;
bool End = false;
bool KillLoop = false;
bool Changed = false;
};
SmallVector<BlockData, SmallVectorThreshold> Block;

iterator_range<pred_iterator> predecessors(BlockData const &BD) const {
BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]);
return llvm::predecessors(BB);
}

BlockData &getBlockData(BasicBlock *BB) {
return Block[Mapping.blockToIndex(BB)];
}

/// Compute the BlockData for the current function in one iteration.
/// Initialize - Whether this is the first iteration, we can optimize
/// the initial case a little bit by manual loop switch.
/// Returns whether the BlockData changes in this iteration.
template <bool Initialize = false>
bool computeBlockData(const ReversePostOrderTraversal<Function *> &RPOT);

public:
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
void dump() const;
void dump(StringRef Label, BitVector const &BV,
const ReversePostOrderTraversal<Function *> &RPOT) const;
#endif

SuspendCrossingInfo(Function &F, coro::Shape &Shape);

/// Returns true if there is a path from \p From to \p To crossing a suspend
/// point without crossing \p From a 2nd time.
bool hasPathCrossingSuspendPoint(BasicBlock *From, BasicBlock *To) const {
size_t const FromIndex = Mapping.blockToIndex(From);
size_t const ToIndex = Mapping.blockToIndex(To);
bool const Result = Block[ToIndex].Kills[FromIndex];
LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName()
<< " answer is " << Result << "\n");
return Result;
}

/// Returns true if there is a path from \p From to \p To crossing a suspend
/// point without crossing \p From a 2nd time. If \p From is the same as \p To
/// this will also check if there is a looping path crossing a suspend point.
bool hasPathOrLoopCrossingSuspendPoint(BasicBlock *From,
BasicBlock *To) const {
size_t const FromIndex = Mapping.blockToIndex(From);
size_t const ToIndex = Mapping.blockToIndex(To);
bool Result = Block[ToIndex].Kills[FromIndex] ||
(From == To && Block[ToIndex].KillLoop);
LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName()
<< " answer is " << Result << " (path or loop)\n");
return Result;
}

bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const {
auto *I = cast<Instruction>(U);

// We rewrote PHINodes, so that only the ones with exactly one incoming
// value need to be analyzed.
if (auto *PN = dyn_cast<PHINode>(I))
if (PN->getNumIncomingValues() > 1)
return false;

BasicBlock *UseBB = I->getParent();

// As a special case, treat uses by an llvm.coro.suspend.retcon or an
// llvm.coro.suspend.async as if they were uses in the suspend's single
// predecessor: the uses conceptually occur before the suspend.
if (isa<CoroSuspendRetconInst>(I) || isa<CoroSuspendAsyncInst>(I)) {
UseBB = UseBB->getSinglePredecessor();
assert(UseBB && "should have split coro.suspend into its own block");
}

return hasPathCrossingSuspendPoint(DefBB, UseBB);
}

bool isDefinitionAcrossSuspend(Argument &A, User *U) const {
return isDefinitionAcrossSuspend(&A.getParent()->getEntryBlock(), U);
}

bool isDefinitionAcrossSuspend(Instruction &I, User *U) const {
auto *DefBB = I.getParent();

// As a special case, treat values produced by an llvm.coro.suspend.*
// as if they were defined in the single successor: the uses
// conceptually occur after the suspend.
if (isa<AnyCoroSuspendInst>(I)) {
DefBB = DefBB->getSingleSuccessor();
assert(DefBB && "should have split coro.suspend into its own block");
}

return isDefinitionAcrossSuspend(DefBB, U);
}

bool isDefinitionAcrossSuspend(Value &V, User *U) const {
if (auto *Arg = dyn_cast<Argument>(&V))
return isDefinitionAcrossSuspend(*Arg, U);
if (auto *Inst = dyn_cast<Instruction>(&V))
return isDefinitionAcrossSuspend(*Inst, U);

llvm_unreachable(
"Coroutine could only collect Argument and Instruction now.");
}
};
} // end anonymous namespace

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
static std::string getBasicBlockLabel(const BasicBlock *BB) {
if (BB->hasName())
return BB->getName().str();

std::string S;
raw_string_ostream OS(S);
BB->printAsOperand(OS, false);
return OS.str().substr(1);
}

// Prints Label followed by the labels of all blocks whose bit is set in BV.
// Blocks are visited in the order provided by RPOT.
LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(
    StringRef Label, BitVector const &BV,
    const ReversePostOrderTraversal<Function *> &RPOT) const {
  dbgs() << Label << ":";
  for (const BasicBlock *Candidate : RPOT) {
    const auto Index = Mapping.blockToIndex(Candidate);
    if (!BV[Index])
      continue;
    dbgs() << " " << getBasicBlockLabel(Candidate);
  }
  dbgs() << "\n";
}

// Prints the Consumes and Kills sets of every block, one block per group.
// Does nothing when no block data has been recorded.
LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
  if (Block.empty())
    return;

  // The block stored at index 0 is used to get at the enclosing function.
  Function *Parent = Mapping.indexToBlock(0)->getParent();
  ReversePostOrderTraversal<Function *> RPOT(Parent);

  for (const BasicBlock *Current : RPOT) {
    const auto &Data = Block[Mapping.blockToIndex(Current)];
    dbgs() << getBasicBlockLabel(Current) << ":\n";
    dump(" Consumes", Data.Consumes, RPOT);
    dump(" Kills", Data.Kills, RPOT);
  }
  dbgs() << "\n";
}
#endif

// Runs one propagation pass over all blocks in the order given by RPOT,
// merging the Consumes/Kills sets of each block's predecessors into the block
// and then applying the suspend / coro.end / plain-block rule to its Kills.
//
// Initialize - when true, the "did any predecessor change" shortcut is
//              skipped and the per-block Changed flags are left untouched;
//              the function then always returns false.
// Returns true when, in a non-Initialize pass, the Consumes or Kills set of
// at least one block differs from its value before the pass.
template <bool Initialize>
bool SuspendCrossingInfo::computeBlockData(
const ReversePostOrderTraversal<Function *> &RPOT) {
// Accumulates whether any block changed during this pass.
bool Changed = false;

for (const BasicBlock *BB : RPOT) {
auto BBNo = Mapping.blockToIndex(BB);
auto &B = Block[BBNo];

// The predecessors do not need to be checked during the initialization pass.
if constexpr (!Initialize)
// If none of the predecessors of the current block changed,
// the BlockData of the current block cannot change either.
if (all_of(predecessors(B), [this](BasicBlock *BB) {
return !Block[Mapping.blockToIndex(BB)].Changed;
})) {
B.Changed = false;
continue;
}

// Save the Consumes and Kills bitsets so that it is easy to see
// whether anything changed after propagation.
auto SavedConsumes = B.Consumes;
auto SavedKills = B.Kills;

for (BasicBlock *PI : predecessors(B)) {
auto PrevNo = Mapping.blockToIndex(PI);
auto &P = Block[PrevNo];

// Propagate Kills and Consumes from predecessors into B.
B.Consumes |= P.Consumes;
B.Kills |= P.Kills;

// If block P is a suspend block, it should propagate kills into block
// B for every block P consumes.
if (P.Suspend)
B.Kills |= P.Consumes;
}

if (B.Suspend) {
// If block B is a suspend block, it should kill all of the blocks it
// consumes.
B.Kills |= B.Consumes;
} else if (B.End) {
// If block B is an end block, it should not propagate kills as the
// blocks following coro.end() are reached during initial invocation
// of the coroutine while all the data are still available on the
// stack or in the registers.
B.Kills.reset();
} else {
// This is reached when block B is neither a suspend block nor a coro.end
// block. B must not be in its own kill set, so remember in KillLoop
// whether it was there and then clear its bit.
B.KillLoop |= B.Kills[BBNo];
B.Kills.reset(BBNo);
}

// Only non-initialization passes track per-block and overall change.
if constexpr (!Initialize) {
B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
Changed |= B.Changed;
}
}

return Changed;
}

// Builds the suspend-crossing data for F: seeds every block, marks the
// coro.end and suspend blocks recorded in Shape, and then iterates
// computeBlockData until it reports no further change.
SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
    : Mapping(F) {
  const size_t NumBlocks = Mapping.size();
  Block.resize(NumBlocks);

  // Seed the per-block records: each block consumes itself and is flagged as
  // changed so that the first regular pass visits its successors.
  size_t SelfIndex = 0;
  for (BlockData &Data : Block) {
    Data.Consumes.resize(NumBlocks);
    Data.Kills.resize(NumBlocks);
    Data.Consumes.set(SelfIndex++);
    Data.Changed = true;
  }

  // Flag every block holding a coro.end. Kills are not propagated beyond
  // coro.end because the code after it is reachable during the initial
  // invocation of the coroutine.
  for (auto *End : Shape.CoroEnds) {
    BlockData &Data = getBlockData(End->getParent());
    Data.End = true;
  }

  // Flag every suspend block and let it kill everything it consumes.
  // Crossing a coro.save requires a spill as well: code between coro.save and
  // coro.suspend may resume the coroutine, so all state has to be saved by
  // then.
  auto MarkSuspend = [this](IntrinsicInst *Barrier) {
    BlockData &Data = getBlockData(Barrier->getParent());
    Data.Suspend = true;
    Data.Kills |= Data.Consumes;
  };
  for (auto *Suspend : Shape.CoroSuspends) {
    MarkSuspend(Suspend);
    auto *Save = Suspend->getCoroSave();
    if (Save)
      MarkSuspend(Save);
  }

  // An RPO traversal is considered the faster choice for a forward-edge
  // dataflow analysis. Run the initialization pass once, then keep running
  // regular passes for as long as one of them reports a change.
  ReversePostOrderTraversal<Function *> RPOT(&F);
  computeBlockData</*Initialize=*/true>(RPOT);
  bool KeepGoing;
  do {
    KeepGoing = computeBlockData</*Initialize*/ false>(RPOT);
  } while (KeepGoing);

  LLVM_DEBUG(dump());
}

namespace {

// RematGraph is used to construct a DAG for rematerializable instructions
Expand Down Expand Up @@ -438,6 +130,16 @@ struct RematGraph {
}

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
static std::string getBasicBlockLabel(const BasicBlock *BB) {
if (BB->hasName())
return BB->getName().str();

std::string S;
raw_string_ostream OS(S);
BB->printAsOperand(OS, false);
return OS.str().substr(1);
}

void dump() const {
dbgs() << "Entry (";
dbgs() << getBasicBlockLabel(EntryNode->Node->getParent());
Expand Down Expand Up @@ -3159,7 +2861,7 @@ void coro::buildCoroutineFrame(
rewritePHIs(F);

// Build suspend crossing info.
SuspendCrossingInfo Checker(F, Shape);
SuspendCrossingInfo Checker(F, Shape.CoroSuspends, Shape.CoroEnds);

doRematerializations(F, Checker, MaterializableCallback);

Expand Down
Loading
Loading