Skip to content

[Coroutines] Refactor CoroShape::buildFrom for future use by ABI objects #108623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 46 additions & 8 deletions llvm/lib/Transforms/Coroutines/CoroShape.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,45 @@ enum class ABI {
// Holds structural Coroutine Intrinsics for a particular function and other
// values used during CoroSplit pass.
struct LLVM_LIBRARY_VISIBILITY Shape {
CoroBeginInst *CoroBegin;
CoroBeginInst *CoroBegin = nullptr;
SmallVector<AnyCoroEndInst *, 4> CoroEnds;
SmallVector<CoroSizeInst *, 2> CoroSizes;
SmallVector<CoroAlignInst *, 2> CoroAligns;
SmallVector<AnyCoroSuspendInst *, 4> CoroSuspends;
SmallVector<CallInst *, 2> SwiftErrorOps;
SmallVector<CoroAwaitSuspendInst *, 4> CoroAwaitSuspends;
SmallVector<CallInst *, 2> SymmetricTransfers;

// Values invalidated by replaceSwiftErrorOps()
SmallVector<CallInst *, 2> SwiftErrorOps;

void clear() {
CoroBegin = nullptr;
CoroEnds.clear();
CoroSizes.clear();
CoroAligns.clear();
CoroSuspends.clear();
CoroAwaitSuspends.clear();
SymmetricTransfers.clear();

SwiftErrorOps.clear();

FrameTy = nullptr;
FramePtr = nullptr;
AllocaSpillBlock = nullptr;
}

// Scan the function and collect the above intrinsics for later processing
void analyze(Function &F, SmallVectorImpl<CoroFrameInst *> &CoroFrames,
SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves);
// If for some reason, we were not able to find coro.begin, bailout.
void invalidateCoroutine(Function &F,
SmallVectorImpl<CoroFrameInst *> &CoroFrames);
// Perform ABI related initial transformation
void initABI();
// Remove orphaned and unnecessary intrinsics
void cleanCoroutine(SmallVectorImpl<CoroFrameInst *> &CoroFrames,
SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves);

// Field indexes for special fields in the switch lowering.
struct SwitchFieldIndex {
enum {
Expand All @@ -76,11 +106,11 @@ struct LLVM_LIBRARY_VISIBILITY Shape {

coro::ABI ABI;

StructType *FrameTy;
StructType *FrameTy = nullptr;
Align FrameAlign;
uint64_t FrameSize;
Value *FramePtr;
BasicBlock *AllocaSpillBlock;
uint64_t FrameSize = 0;
Value *FramePtr = nullptr;
BasicBlock *AllocaSpillBlock = nullptr;

/// This would only be true if optimization are enabled.
bool OptimizeFrame;
Expand Down Expand Up @@ -237,9 +267,17 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
Shape() = default;
explicit Shape(Function &F, bool OptimizeFrame = false)
: OptimizeFrame(OptimizeFrame) {
buildFrom(F);
SmallVector<CoroFrameInst *, 8> CoroFrames;
SmallVector<CoroSaveInst *, 2> UnusedCoroSaves;

analyze(F, CoroFrames, UnusedCoroSaves);
if (!CoroBegin) {
invalidateCoroutine(F, CoroFrames);
return;
}
initABI();
cleanCoroutine(CoroFrames, UnusedCoroSaves);
}
void buildFrom(Function &F);
};

} // end namespace coro
Expand Down
154 changes: 86 additions & 68 deletions llvm/lib/Transforms/Coroutines/Coroutines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,6 @@ void coro::suppressCoroAllocs(LLVMContext &Context,
}
}

static void clear(coro::Shape &Shape) {
Shape.CoroBegin = nullptr;
Shape.CoroEnds.clear();
Shape.CoroSizes.clear();
Shape.CoroSuspends.clear();

Shape.FrameTy = nullptr;
Shape.FramePtr = nullptr;
Shape.AllocaSpillBlock = nullptr;
}

static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
CoroSuspendInst *SuspendInst) {
Module *M = SuspendInst->getModule();
Expand All @@ -200,13 +189,14 @@ static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
}

// Collect "interesting" coroutine intrinsics.
void coro::Shape::buildFrom(Function &F) {
void coro::Shape::analyze(Function &F,
SmallVectorImpl<CoroFrameInst *> &CoroFrames,
SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves) {
clear();

bool HasFinalSuspend = false;
bool HasUnwindCoroEnd = false;
size_t FinalSuspendIndex = 0;
clear(*this);
SmallVector<CoroFrameInst *, 8> CoroFrames;
SmallVector<CoroSaveInst *, 2> UnusedCoroSaves;

for (Instruction &I : instructions(F)) {
// FIXME: coro_await_suspend_* are not proper `IntrinisicInst`s
Expand Down Expand Up @@ -298,15 +288,73 @@ void coro::Shape::buildFrom(Function &F) {
}
}

// If for some reason, we were not able to find coro.begin, bailout.
if (!CoroBegin) {
// If there is no CoroBegin then this is not a coroutine.
if (!CoroBegin)
return;

// Determination of ABI and initializing lowering info
auto Id = CoroBegin->getId();
switch (auto IntrID = Id->getIntrinsicID()) {
case Intrinsic::coro_id: {
ABI = coro::ABI::Switch;
SwitchLowering.HasFinalSuspend = HasFinalSuspend;
SwitchLowering.HasUnwindCoroEnd = HasUnwindCoroEnd;

auto SwitchId = getSwitchCoroId();
SwitchLowering.ResumeSwitch = nullptr;
SwitchLowering.PromiseAlloca = SwitchId->getPromise();
SwitchLowering.ResumeEntryBlock = nullptr;

// Move final suspend to the last element in the CoroSuspends vector.
if (SwitchLowering.HasFinalSuspend &&
FinalSuspendIndex != CoroSuspends.size() - 1)
std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
break;
}
case Intrinsic::coro_id_async: {
ABI = coro::ABI::Async;
auto *AsyncId = getAsyncCoroId();
AsyncId->checkWellFormed();
AsyncLowering.Context = AsyncId->getStorage();
AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
AsyncLowering.ContextAlignment = AsyncId->getStorageAlignment().value();
AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
AsyncLowering.AsyncCC = F.getCallingConv();
break;
}
case Intrinsic::coro_id_retcon:
case Intrinsic::coro_id_retcon_once: {
ABI = IntrID == Intrinsic::coro_id_retcon ? coro::ABI::Retcon
: coro::ABI::RetconOnce;
auto ContinuationId = getRetconCoroId();
ContinuationId->checkWellFormed();
auto Prototype = ContinuationId->getPrototype();
RetconLowering.ResumePrototype = Prototype;
RetconLowering.Alloc = ContinuationId->getAllocFunction();
RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
RetconLowering.ReturnBlock = nullptr;
RetconLowering.IsFrameInlineInStorage = false;
break;
}
default:
llvm_unreachable("coro.begin is not dependent on a coro.id call");
}
}

// If for some reason, we were not able to find coro.begin, bailout.
void coro::Shape::invalidateCoroutine(
Function &F, SmallVectorImpl<CoroFrameInst *> &CoroFrames) {
assert(!CoroBegin);
{
// Replace coro.frame which are supposed to be lowered to the result of
// coro.begin with undef.
auto *Undef = UndefValue::get(PointerType::get(F.getContext(), 0));
for (CoroFrameInst *CF : CoroFrames) {
CF->replaceAllUsesWith(Undef);
CF->eraseFromParent();
}
CoroFrames.clear();

// Replace all coro.suspend with undef and remove related coro.saves if
// present.
Expand All @@ -316,25 +364,18 @@ void coro::Shape::buildFrom(Function &F) {
if (auto *CoroSave = CS->getCoroSave())
CoroSave->eraseFromParent();
}
CoroSuspends.clear();

// Replace all coro.ends with unreachable instruction.
for (AnyCoroEndInst *CE : CoroEnds)
changeToUnreachable(CE);

return;
}
}

auto Id = CoroBegin->getId();
switch (auto IdIntrinsic = Id->getIntrinsicID()) {
case Intrinsic::coro_id: {
auto SwitchId = cast<CoroIdInst>(Id);
this->ABI = coro::ABI::Switch;
this->SwitchLowering.HasFinalSuspend = HasFinalSuspend;
this->SwitchLowering.HasUnwindCoroEnd = HasUnwindCoroEnd;
this->SwitchLowering.ResumeSwitch = nullptr;
this->SwitchLowering.PromiseAlloca = SwitchId->getPromise();
this->SwitchLowering.ResumeEntryBlock = nullptr;

// Perform semantic checking and initialization of the ABI
void coro::Shape::initABI() {
switch (ABI) {
case coro::ABI::Switch: {
for (auto *AnySuspend : CoroSuspends) {
auto Suspend = dyn_cast<CoroSuspendInst>(AnySuspend);
if (!Suspend) {
Expand All @@ -349,33 +390,11 @@ void coro::Shape::buildFrom(Function &F) {
}
break;
}
case Intrinsic::coro_id_async: {
auto *AsyncId = cast<CoroIdAsyncInst>(Id);
AsyncId->checkWellFormed();
this->ABI = coro::ABI::Async;
this->AsyncLowering.Context = AsyncId->getStorage();
this->AsyncLowering.ContextArgNo = AsyncId->getStorageArgumentIndex();
this->AsyncLowering.ContextHeaderSize = AsyncId->getStorageSize();
this->AsyncLowering.ContextAlignment =
AsyncId->getStorageAlignment().value();
this->AsyncLowering.AsyncFuncPointer = AsyncId->getAsyncFunctionPointer();
this->AsyncLowering.AsyncCC = F.getCallingConv();
case coro::ABI::Async: {
break;
};
case Intrinsic::coro_id_retcon:
case Intrinsic::coro_id_retcon_once: {
auto ContinuationId = cast<AnyCoroIdRetconInst>(Id);
ContinuationId->checkWellFormed();
this->ABI = (IdIntrinsic == Intrinsic::coro_id_retcon
? coro::ABI::Retcon
: coro::ABI::RetconOnce);
auto Prototype = ContinuationId->getPrototype();
this->RetconLowering.ResumePrototype = Prototype;
this->RetconLowering.Alloc = ContinuationId->getAllocFunction();
this->RetconLowering.Dealloc = ContinuationId->getDeallocFunction();
this->RetconLowering.ReturnBlock = nullptr;
this->RetconLowering.IsFrameInlineInStorage = false;

case coro::ABI::Retcon:
case coro::ABI::RetconOnce: {
// Determine the result value types, and make sure they match up with
// the values passed to the suspends.
auto ResultTys = getRetconResultTypes();
Expand Down Expand Up @@ -408,7 +427,7 @@ void coro::Shape::buildFrom(Function &F) {

#ifndef NDEBUG
Suspend->dump();
Prototype->getFunctionType()->dump();
RetconLowering.ResumePrototype->getFunctionType()->dump();
#endif
report_fatal_error("argument to coro.suspend.retcon does not "
"match corresponding prototype function result");
Expand All @@ -417,14 +436,14 @@ void coro::Shape::buildFrom(Function &F) {
if (SI != SE || RI != RE) {
#ifndef NDEBUG
Suspend->dump();
Prototype->getFunctionType()->dump();
RetconLowering.ResumePrototype->getFunctionType()->dump();
#endif
report_fatal_error("wrong number of arguments to coro.suspend.retcon");
}

// Check that the result type of the suspend matches the resume types.
Type *SResultTy = Suspend->getType();
ArrayRef<Type*> SuspendResultTys;
ArrayRef<Type *> SuspendResultTys;
if (SResultTy->isVoidTy()) {
// leave as empty array
} else if (auto SResultStructTy = dyn_cast<StructType>(SResultTy)) {
Expand All @@ -436,15 +455,15 @@ void coro::Shape::buildFrom(Function &F) {
if (SuspendResultTys.size() != ResumeTys.size()) {
#ifndef NDEBUG
Suspend->dump();
Prototype->getFunctionType()->dump();
RetconLowering.ResumePrototype->getFunctionType()->dump();
#endif
report_fatal_error("wrong number of results from coro.suspend.retcon");
}
for (size_t I = 0, E = ResumeTys.size(); I != E; ++I) {
if (SuspendResultTys[I] != ResumeTys[I]) {
#ifndef NDEBUG
Suspend->dump();
Prototype->getFunctionType()->dump();
RetconLowering.ResumePrototype->getFunctionType()->dump();
#endif
report_fatal_error("result from coro.suspend.retcon does not "
"match corresponding prototype function param");
Expand All @@ -453,26 +472,25 @@ void coro::Shape::buildFrom(Function &F) {
}
break;
}

default:
llvm_unreachable("coro.begin is not dependent on a coro.id call");
}
}

// The coro.free intrinsic is always lowered to the result of coro.begin.
void coro::Shape::cleanCoroutine(
SmallVectorImpl<CoroFrameInst *> &CoroFrames,
SmallVectorImpl<CoroSaveInst *> &UnusedCoroSaves) {
// The coro.frame intrinsic is always lowered to the result of coro.begin.
for (CoroFrameInst *CF : CoroFrames) {
CF->replaceAllUsesWith(CoroBegin);
CF->eraseFromParent();
}

// Move final suspend to be the last element in the CoroSuspends vector.
if (ABI == coro::ABI::Switch &&
SwitchLowering.HasFinalSuspend &&
FinalSuspendIndex != CoroSuspends.size() - 1)
std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
CoroFrames.clear();

// Remove orphaned coro.saves.
for (CoroSaveInst *CoroSave : UnusedCoroSaves)
CoroSave->eraseFromParent();
UnusedCoroSaves.clear();
}

static void propagateCallAttrsFromCallee(CallInst *Call, Function *Callee) {
Expand Down
Loading