Skip to content

[SimpleLoopUnswitch] Remove callbacks #73300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 24, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 76 additions & 88 deletions llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1682,13 +1682,12 @@ deleteDeadClonedBlocks(Loop &L, ArrayRef<BasicBlock *> ExitBlocks,
BB->eraseFromParent();
}

static void
deleteDeadBlocksFromLoop(Loop &L,
SmallVectorImpl<BasicBlock *> &ExitBlocks,
DominatorTree &DT, LoopInfo &LI,
MemorySSAUpdater *MSSAU,
ScalarEvolution *SE,
function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
static void deleteDeadBlocksFromLoop(Loop &L,
SmallVectorImpl<BasicBlock *> &ExitBlocks,
DominatorTree &DT, LoopInfo &LI,
MemorySSAUpdater *MSSAU,
ScalarEvolution *SE,
LPMUpdater &LoopUpdater) {
// Find all the dead blocks tied to this loop, and remove them from their
// successors.
SmallSetVector<BasicBlock *, 8> DeadBlockSet;
Expand Down Expand Up @@ -1738,7 +1737,7 @@ deleteDeadBlocksFromLoop(Loop &L,
}) &&
"If the child loop header is dead all blocks in the child loop must "
"be dead as well!");
DestroyLoopCB(*ChildL, ChildL->getName());
LoopUpdater.markLoopAsDeleted(*ChildL, ChildL->getName());
if (SE)
SE->forgetBlockAndLoopDispositions();
LI.destroy(ChildL);
Expand Down Expand Up @@ -2082,8 +2081,8 @@ static bool rebuildLoopAfterUnswitch(Loop &L, ArrayRef<BasicBlock *> ExitBlocks,
ParentL->removeChildLoop(llvm::find(*ParentL, &L));
else
LI.removeLoop(llvm::find(LI, &L));
// markLoopAsDeleted for L should be triggered by the caller (it is typically
// done by using the UnswitchCB callback).
// markLoopAsDeleted for L should be triggered by the caller (it is
// typically done within postUnswitch).
if (SE)
SE->forgetBlockAndLoopDispositions();
LI.destroy(&L);
Expand Down Expand Up @@ -2120,18 +2119,56 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) {
} while (!DomWorklist.empty());
}

void postUnswitch(Loop &L, LPMUpdater &U, StringRef LoopName,
bool CurrentLoopValid, bool PartiallyInvariant,
bool InjectedCondition, ArrayRef<Loop *> NewLoops) {
// If we did a non-trivial unswitch, we have added new (cloned) loops.
if (!NewLoops.empty())
U.addSiblingLoops(NewLoops);

// If the current loop remains valid, we should revisit it to catch any
// other unswitch opportunities. Otherwise, we need to mark it as deleted.
if (CurrentLoopValid) {
if (PartiallyInvariant) {
// Mark the new loop as partially unswitched, to avoid unswitching on
// the same condition again.
auto &Context = L.getHeader()->getContext();
MDNode *DisableUnswitchMD = MDNode::get(
Context,
MDString::get(Context, "llvm.loop.unswitch.partial.disable"));
MDNode *NewLoopID = makePostTransformationMetadata(
Context, L.getLoopID(), {"llvm.loop.unswitch.partial"},
{DisableUnswitchMD});
L.setLoopID(NewLoopID);
} else if (InjectedCondition) {
// Do the same for injection of invariant conditions.
auto &Context = L.getHeader()->getContext();
MDNode *DisableUnswitchMD = MDNode::get(
Context,
MDString::get(Context, "llvm.loop.unswitch.injection.disable"));
MDNode *NewLoopID = makePostTransformationMetadata(
Context, L.getLoopID(), {"llvm.loop.unswitch.injection"},
{DisableUnswitchMD});
L.setLoopID(NewLoopID);
} else
U.revisitCurrentLoop();
} else
U.markLoopAsDeleted(L, LoopName);
}

static void unswitchNontrivialInvariants(
Loop &L, Instruction &TI, ArrayRef<Value *> Invariants,
IVConditionInfo &PartialIVInfo, DominatorTree &DT, LoopInfo &LI,
AssumptionCache &AC,
function_ref<void(bool, bool, bool, ArrayRef<Loop *>)> UnswitchCB,
ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
function_ref<void(Loop &, StringRef)> DestroyLoopCB, bool InsertFreeze,
bool InjectedCondition) {
AssumptionCache &AC, ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
LPMUpdater &LoopUpdater, bool InsertFreeze, bool InjectedCondition) {
auto *ParentBB = TI.getParent();
BranchInst *BI = dyn_cast<BranchInst>(&TI);
SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI);

// Save the current loop name in a variable so that we can report it even
// after it has been deleted.
std::string LoopName(L.getName());

// We can only unswitch switches, conditional branches with an invariant
// condition, or combining invariant conditions with an instruction or
// partially invariant instructions.
Expand Down Expand Up @@ -2444,7 +2481,7 @@ static void unswitchNontrivialInvariants(
// Now that our cloned loops have been built, we can update the original loop.
// First we delete the dead blocks from it and then we rebuild the loop
// structure taking these deletions into account.
deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE,DestroyLoopCB);
deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE, LoopUpdater);

if (MSSAU && VerifyMemorySSA)
MSSAU->getMemorySSA()->verifyMemorySSA();
Expand Down Expand Up @@ -2580,7 +2617,8 @@ static void unswitchNontrivialInvariants(
for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops))
if (UpdatedL->getParentLoop() == ParentL)
SibLoops.push_back(UpdatedL);
UnswitchCB(IsStillLoop, PartiallyInvariant, InjectedCondition, SibLoops);
postUnswitch(L, LoopUpdater, LoopName, IsStillLoop, PartiallyInvariant,
InjectedCondition, SibLoops);

if (MSSAU && VerifyMemorySSA)
MSSAU->getMemorySSA()->verifyMemorySSA();
Expand Down Expand Up @@ -3427,12 +3465,11 @@ static bool shouldInsertFreeze(Loop &L, Instruction &TI, DominatorTree &DT,
Cond, &AC, L.getLoopPreheader()->getTerminator(), &DT);
}

static bool unswitchBestCondition(
Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
AAResults &AA, TargetTransformInfo &TTI,
function_ref<void(bool, bool, bool, ArrayRef<Loop *>)> UnswitchCB,
ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
static bool unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
AssumptionCache &AC, AAResults &AA,
TargetTransformInfo &TTI, ScalarEvolution *SE,
MemorySSAUpdater *MSSAU,
LPMUpdater &LoopUpdater) {
// Collect all invariant conditions within this loop (as opposed to an inner
// loop which would be handled when visiting that inner loop).
SmallVector<NonTrivialUnswitchCandidate, 4> UnswitchCandidates;
Expand Down Expand Up @@ -3495,8 +3532,8 @@ static bool unswitchBestCondition(
LLVM_DEBUG(dbgs() << " Unswitching non-trivial (cost = " << Best.Cost
<< ") terminator: " << *Best.TI << "\n");
unswitchNontrivialInvariants(L, *Best.TI, Best.Invariants, PartialIVInfo, DT,
LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB,
InsertFreeze, InjectedCondition);
LI, AC, SE, MSSAU, LoopUpdater, InsertFreeze,
InjectedCondition);
return true;
}

Expand All @@ -3515,20 +3552,18 @@ static bool unswitchBestCondition(
/// true, we will attempt to do non-trivial unswitching as well as trivial
/// unswitching.
///
/// The `UnswitchCB` callback provided will be run after unswitching is
/// complete, with the first parameter set to `true` if the provided loop
/// remains a loop, and a list of new sibling loops created.
/// The `postUnswitch` function will be run after unswitching is complete
/// with information on whether or not the provided loop remains a loop and
/// a list of new sibling loops created.
///
/// If `SE` is non-null, we will update that analysis based on the unswitching
/// done.
static bool
unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
AAResults &AA, TargetTransformInfo &TTI, bool Trivial,
bool NonTrivial,
function_ref<void(bool, bool, bool, ArrayRef<Loop *>)> UnswitchCB,
ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI,
function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI,
AssumptionCache &AC, AAResults &AA,
TargetTransformInfo &TTI, bool Trivial,
bool NonTrivial, ScalarEvolution *SE,
MemorySSAUpdater *MSSAU, ProfileSummaryInfo *PSI,
BlockFrequencyInfo *BFI, LPMUpdater &LoopUpdater) {
assert(L.isRecursivelyLCSSAForm(DT, LI) &&
"Loops must be in LCSSA form before unswitching.");

Expand All @@ -3540,8 +3575,9 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
if (Trivial && unswitchAllTrivialConditions(L, DT, LI, SE, MSSAU)) {
// If we unswitched successfully we will want to clean up the loop before
// processing it further so just mark it as unswitched and return.
UnswitchCB(/*CurrentLoopValid*/ true, /*PartiallyInvariant*/ false,
/*InjectedCondition*/ false, {});
postUnswitch(L, LoopUpdater, L.getName(),
/*CurrentLoopValid*/ true, /*PartiallyInvariant*/ false,
/*InjectedCondition*/ false, {});
return true;
}

Expand Down Expand Up @@ -3610,8 +3646,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,

// Try to unswitch the best invariant condition. We prefer this full unswitch to
// a partial unswitch when possible below the threshold.
if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, UnswitchCB, SE, MSSAU,
DestroyLoopCB))
if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, SE, MSSAU, LoopUpdater))
return true;

// No other opportunities to unswitch.
Expand All @@ -3631,61 +3666,14 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L
<< "\n");

// Save the current loop name in a variable so that we can report it even
// after it has been deleted.
std::string LoopName = std::string(L.getName());

auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid,
bool PartiallyInvariant,
bool InjectedCondition,
ArrayRef<Loop *> NewLoops) {
// If we did a non-trivial unswitch, we have added new (cloned) loops.
if (!NewLoops.empty())
U.addSiblingLoops(NewLoops);

// If the current loop remains valid, we should revisit it to catch any
// other unswitch opportunities. Otherwise, we need to mark it as deleted.
if (CurrentLoopValid) {
if (PartiallyInvariant) {
// Mark the new loop as partially unswitched, to avoid unswitching on
// the same condition again.
auto &Context = L.getHeader()->getContext();
MDNode *DisableUnswitchMD = MDNode::get(
Context,
MDString::get(Context, "llvm.loop.unswitch.partial.disable"));
MDNode *NewLoopID = makePostTransformationMetadata(
Context, L.getLoopID(), {"llvm.loop.unswitch.partial"},
{DisableUnswitchMD});
L.setLoopID(NewLoopID);
} else if (InjectedCondition) {
// Do the same for injection of invariant conditions.
auto &Context = L.getHeader()->getContext();
MDNode *DisableUnswitchMD = MDNode::get(
Context,
MDString::get(Context, "llvm.loop.unswitch.injection.disable"));
MDNode *NewLoopID = makePostTransformationMetadata(
Context, L.getLoopID(), {"llvm.loop.unswitch.injection"},
{DisableUnswitchMD});
L.setLoopID(NewLoopID);
} else
U.revisitCurrentLoop();
} else
U.markLoopAsDeleted(L, LoopName);
};

auto DestroyLoopCB = [&U](Loop &L, StringRef Name) {
U.markLoopAsDeleted(L, Name);
};

std::optional<MemorySSAUpdater> MSSAU;
if (AR.MSSA) {
MSSAU = MemorySSAUpdater(AR.MSSA);
if (VerifyMemorySSA)
AR.MSSA->verifyMemorySSA();
}
if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial,
UnswitchCB, &AR.SE, MSSAU ? &*MSSAU : nullptr, PSI, AR.BFI,
DestroyLoopCB))
&AR.SE, MSSAU ? &*MSSAU : nullptr, PSI, AR.BFI, U))
return PreservedAnalyses::all();

if (AR.MSSA && VerifyMemorySSA)
Expand Down