Skip to content

[Coroutines] ABI Objects to improve code separation between different ABIs, users and utilities. #109713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,26 @@

namespace llvm {

namespace coro {
class BaseABI;
class Shape;
} // namespace coro

struct CoroSplitPass : PassInfoMixin<CoroSplitPass> {
const std::function<bool(Instruction &)> MaterializableCallback;

CoroSplitPass(bool OptimizeFrame = false);
CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
bool OptimizeFrame = false)
: MaterializableCallback(MaterializableCallback),
OptimizeFrame(OptimizeFrame) {}
bool OptimizeFrame = false);

PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
LazyCallGraph &CG, CGSCCUpdateResult &UR);
static bool isRequired() { return true; }

using BaseABITy =
std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>;
// Generator for an ABI transformer
BaseABITy CreateAndInitABI;

// Would be true if the Optimization level isn't O0.
bool OptimizeFrame;
};
Expand Down
102 changes: 102 additions & 0 deletions llvm/lib/Transforms/Coroutines/ABI.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
//===- ABI.h - Coroutine lowering class definitions (ABIs) ----*- C++ -*---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file defines coroutine lowering classes. The interface for coroutine
// lowering is defined by BaseABI. Each lowering method (ABI) implements the
// interface. Note that the enum class ABI, such as ABI::Switch, determines
// which ABI class, such as SwitchABI, is used to lower the coroutine. Both the
// ABI enum and ABI class are used by the Coroutine passes when lowering.
//===----------------------------------------------------------------------===//

#ifndef LIB_TRANSFORMS_COROUTINES_ABI_H
#define LIB_TRANSFORMS_COROUTINES_ABI_H

#include "CoroShape.h"
#include "SuspendCrossingInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"

namespace llvm {

class Function;

namespace coro {

// This interface/API is to provide an object oriented way to implement ABI
// functionality. This is intended to replace use of the ABI enum to perform
// ABI operations. The ABIs (e.g. Switch, Async, Retcon{Once}) are the common
// ABIs.

class LLVM_LIBRARY_VISIBILITY BaseABI {
public:
BaseABI(Function &F, coro::Shape &S,
std::function<bool(Instruction &)> IsMaterializable)
: F(F), Shape(S), IsMaterializable(IsMaterializable) {}
virtual ~BaseABI() = default;

// Initialize the coroutine ABI
virtual void init() = 0;

// Allocate the coroutine frame and do spill/reload as needed.
virtual void buildCoroutineFrame();

// Perform the function splitting according to the ABI.
virtual void splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) = 0;

Function &F;
coro::Shape &Shape;

// Callback used by coro::BaseABI::buildCoroutineFrame for rematerialization.
// It is provided to coro::doMaterializations(..).
std::function<bool(Instruction &I)> IsMaterializable;
};

class LLVM_LIBRARY_VISIBILITY SwitchABI : public BaseABI {
public:
SwitchABI(Function &F, coro::Shape &S,
std::function<bool(Instruction &)> IsMaterializable)
: BaseABI(F, S, IsMaterializable) {}

void init() override;

void splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) override;
};

class LLVM_LIBRARY_VISIBILITY AsyncABI : public BaseABI {
public:
AsyncABI(Function &F, coro::Shape &S,
std::function<bool(Instruction &)> IsMaterializable)
: BaseABI(F, S, IsMaterializable) {}

void init() override;

void splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) override;
};

class LLVM_LIBRARY_VISIBILITY AnyRetconABI : public BaseABI {
public:
AnyRetconABI(Function &F, coro::Shape &S,
std::function<bool(Instruction &)> IsMaterializable)
: BaseABI(F, S, IsMaterializable) {}

void init() override;

void splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) override;
};

} // end namespace coro

} // end namespace llvm

#endif // LLVM_TRANSFORMS_COROUTINES_ABI_H
7 changes: 3 additions & 4 deletions llvm/lib/Transforms/Coroutines/CoroFrame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// the value into the coroutine frame.
//===----------------------------------------------------------------------===//

#include "ABI.h"
#include "CoroInternal.h"
#include "MaterializationUtils.h"
#include "SpillUtils.h"
Expand Down Expand Up @@ -2055,11 +2056,9 @@ void coro::normalizeCoroutine(Function &F, coro::Shape &Shape,
rewritePHIs(F);
}

void coro::buildCoroutineFrame(
Function &F, Shape &Shape,
const std::function<bool(Instruction &)> &MaterializableCallback) {
void coro::BaseABI::buildCoroutineFrame() {
SuspendCrossingInfo Checker(F, Shape.CoroSuspends, Shape.CoroEnds);
doRematerializations(F, Checker, MaterializableCallback);
doRematerializations(F, Checker, IsMaterializable);

const DominatorTree DT(F);
if (Shape.ABI != coro::ABI::Async && Shape.ABI != coro::ABI::Retcon &&
Expand Down
3 changes: 0 additions & 3 deletions llvm/lib/Transforms/Coroutines/CoroInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ struct LowererBase {
bool defaultMaterializable(Instruction &V);
void normalizeCoroutine(Function &F, coro::Shape &Shape,
TargetTransformInfo &TTI);
void buildCoroutineFrame(
Function &F, Shape &Shape,
const std::function<bool(Instruction &)> &MaterializableCallback);
CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
TargetTransformInfo &TTI,
ArrayRef<Value *> Arguments, IRBuilder<> &);
Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Transforms/Coroutines/CoroShape.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,6 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
invalidateCoroutine(F, CoroFrames);
return;
}
initABI();
cleanCoroutine(CoroFrames, UnusedCoroSaves);
}
};
Expand Down
109 changes: 73 additions & 36 deletions llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Coroutines/CoroSplit.h"
#include "ABI.h"
#include "CoroInstr.h"
#include "CoroInternal.h"
#include "MaterializationUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PriorityWorklist.h"
#include "llvm/ADT/SmallPtrSet.h"
Expand Down Expand Up @@ -1779,9 +1781,9 @@ CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
return TailCall;
}

static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) {
void coro::AsyncABI::splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) {
assert(Shape.ABI == coro::ABI::Async);
assert(Clones.empty());
// Reset various things that the optimizer might have decided it
Expand Down Expand Up @@ -1874,9 +1876,9 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
}
}

static void splitRetconCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) {
void coro::AnyRetconABI::splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) {
assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce);
assert(Clones.empty());

Expand Down Expand Up @@ -2044,26 +2046,27 @@ static bool hasSafeElideCaller(Function &F) {
return false;
}

static coro::Shape
splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI, bool OptimizeFrame,
std::function<bool(Instruction &)> MaterializableCallback) {
PrettyStackTraceFunction prettyStackTrace(F);
void coro::SwitchABI::splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) {
SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
}

// The suspend-crossing algorithm in buildCoroutineFrame get tripped
// up by uses in unreachable blocks, so remove them as a first pass.
removeUnreachableBlocks(F);
static void doSplitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
coro::BaseABI &ABI, TargetTransformInfo &TTI) {
PrettyStackTraceFunction prettyStackTrace(F);

coro::Shape Shape(F, OptimizeFrame);
if (!Shape.CoroBegin)
return Shape;
auto &Shape = ABI.Shape;
assert(Shape.CoroBegin);

lowerAwaitSuspends(F, Shape);

simplifySuspendPoints(Shape);

normalizeCoroutine(F, Shape, TTI);
buildCoroutineFrame(F, Shape, MaterializableCallback);
ABI.buildCoroutineFrame();
replaceFrameSizeAndAlignment(Shape);

bool isNoSuspendCoroutine = Shape.CoroSuspends.empty();

bool shouldCreateNoAllocVariant = !isNoSuspendCoroutine &&
Expand All @@ -2075,18 +2078,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
if (isNoSuspendCoroutine) {
handleNoSuspendCoroutine(Shape);
} else {
switch (Shape.ABI) {
case coro::ABI::Switch:
SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
break;
case coro::ABI::Async:
splitAsyncCoroutine(F, Shape, Clones, TTI);
break;
case coro::ABI::Retcon:
case coro::ABI::RetconOnce:
splitRetconCoroutine(F, Shape, Clones, TTI);
break;
}
ABI.splitCoroutine(F, Shape, Clones, TTI);
}

// Replace all the swifterror operations in the original function.
Expand All @@ -2107,8 +2099,6 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,

if (shouldCreateNoAllocVariant)
SwitchCoroutineSplitter::createNoAllocVariant(F, Shape, Clones);

return Shape;
}

static LazyCallGraph::SCC &updateCallGraphAfterCoroutineSplit(
Expand Down Expand Up @@ -2207,8 +2197,44 @@ static void addPrepareFunction(const Module &M,
Fns.push_back(PrepareFn);
}

static std::unique_ptr<coro::BaseABI>
CreateNewABI(Function &F, coro::Shape &S,
std::function<bool(Instruction &)> IsMatCallback) {
switch (S.ABI) {
case coro::ABI::Switch:
return std::unique_ptr<coro::BaseABI>(
new coro::SwitchABI(F, S, IsMatCallback));
case coro::ABI::Async:
return std::unique_ptr<coro::BaseABI>(
new coro::AsyncABI(F, S, IsMatCallback));
case coro::ABI::Retcon:
return std::unique_ptr<coro::BaseABI>(
new coro::AnyRetconABI(F, S, IsMatCallback));
case coro::ABI::RetconOnce:
return std::unique_ptr<coro::BaseABI>(
new coro::AnyRetconABI(F, S, IsMatCallback));
}
llvm_unreachable("Unknown ABI");
}

CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
: MaterializableCallback(coro::defaultMaterializable),
: CreateAndInitABI([](Function &F, coro::Shape &S) {
std::unique_ptr<coro::BaseABI> ABI =
CreateNewABI(F, S, coro::isTriviallyMaterializable);
ABI->init();
return std::move(ABI);
}),
OptimizeFrame(OptimizeFrame) {}

// For back compatibility, constructor takes a materializable callback and
// creates a generator for an ABI with a modified materializable callback.
CoroSplitPass::CoroSplitPass(std::function<bool(Instruction &)> IsMatCallback,
bool OptimizeFrame)
: CreateAndInitABI([=](Function &F, coro::Shape &S) {
std::unique_ptr<coro::BaseABI> ABI = CreateNewABI(F, S, IsMatCallback);
ABI->init();
return std::move(ABI);
}),
OptimizeFrame(OptimizeFrame) {}

PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
Expand Down Expand Up @@ -2241,12 +2267,23 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
Function &F = N->getFunction();
LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F.getName()
<< "\n");

// The suspend-crossing algorithm in buildCoroutineFrame gets tripped up
// by unreachable blocks, so remove them as a first pass. Remove the
// unreachable blocks before collecting intrinsics into Shape.
removeUnreachableBlocks(F);

coro::Shape Shape(F, OptimizeFrame);
if (!Shape.CoroBegin)
continue;

F.setSplittedCoroutine();

std::unique_ptr<coro::BaseABI> ABI = CreateAndInitABI(F, Shape);

SmallVector<Function *, 4> Clones;
coro::Shape Shape =
splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F),
OptimizeFrame, MaterializableCallback);
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
doSplitCoroutine(F, Clones, *ABI, TTI);
CurrentSCC = &updateCallGraphAfterCoroutineSplit(
*N, Shape, Clones, *CurrentSCC, CG, AM, UR, FAM);

Expand Down
Loading
Loading