Skip to content

Commit 616d93c

Browse files
committed
[CodeGenPrepare] Folding urem with loop invariant value
``` for(i = Start; i < End; ++i) Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant; ``` -> ``` Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant; for(i = Start; i < End; ++i, ++rem) Rem = rem == RemAmtLoopInvariant ? 0 : Rem; ``` In its current state, only if `IncrLoopInvariant` and `Start` both being zero. Alive2 seemed unable to prove this (see: https://alive2.llvm.org/ce/z/ATGDp3 which is clearly wrong but still checks out...) so wrote an exhaustive test here: https://godbolt.org/z/WYa561388
1 parent c94889e commit 616d93c

File tree

2 files changed

+220
-43
lines changed

2 files changed

+220
-43
lines changed

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ class CodeGenPrepare {
471471
bool replaceMathCmpWithIntrinsic(BinaryOperator *BO, Value *Arg0, Value *Arg1,
472472
CmpInst *Cmp, Intrinsic::ID IID);
473473
bool optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT);
474+
bool optimizeRem(Instruction *Rem);
474475
bool combineToUSubWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
475476
bool combineToUAddWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
476477
void verifyBFIUpdates(Function &F);
@@ -1974,6 +1975,157 @@ static bool foldFCmpToFPClassTest(CmpInst *Cmp, const TargetLowering &TLI,
19741975
return true;
19751976
}
19761977

1978+
static bool isRemOfLoopIncrementWithLIV(Value *Rem, const LoopInfo *LI,
1979+
Value *&RemAmtOut,
1980+
std::optional<bool> &AddOrSubOut,
1981+
Value *&AddOrSubOffsetOut,
1982+
PHINode *&LoopIncrPNOut) {
1983+
Value *Incr, *RemAmt;
1984+
if (!isa<Instruction>(Rem))
1985+
return false;
1986+
// NB: If RemAmt is a power of 2 it *should* have been transformed by now.
1987+
if (!match(Rem, m_URem(m_Value(Incr), m_Value(RemAmt))))
1988+
return false;
1989+
1990+
// Only trivially analyzable loops.
1991+
Loop *L = LI->getLoopFor(cast<Instruction>(Rem)->getParent());
1992+
if (L == nullptr || L->getLoopPreheader() == nullptr ||
1993+
L->getLoopLatch() == nullptr)
1994+
return false;
1995+
1996+
std::optional<bool> AddOrSub;
1997+
Value *AddOrSubOffset;
1998+
// Find out loop increment PHI.
1999+
PHINode *PN = dyn_cast<PHINode>(Incr);
2000+
if (PN != nullptr) {
2001+
AddOrSub = std::nullopt;
2002+
AddOrSubOffset = nullptr;
2003+
} else {
2004+
// Search through a NUW add/sub.
2005+
Value *V0, *V1;
2006+
if (match(Incr, m_NUWAddLike(m_Value(V0), m_Value(V1))))
2007+
AddOrSub = true;
2008+
else if (match(Incr, m_NUWSub(m_Value(V0), m_Value(V1))))
2009+
AddOrSub = false;
2010+
else
2011+
return false;
2012+
2013+
PN = dyn_cast<PHINode>(V0);
2014+
if (PN != nullptr) {
2015+
AddOrSubOffset = V1;
2016+
} else if (*AddOrSub) {
2017+
PN = dyn_cast<PHINode>(V1);
2018+
AddOrSubOffset = V0;
2019+
}
2020+
}
2021+
2022+
if (PN == nullptr)
2023+
return false;
2024+
2025+
// This isn't strictly necessary, what we really need is one increment and any
2026+
// amount of initial values all being the same.
2027+
if (PN->getNumIncomingValues() != 2)
2028+
return false;
2029+
2030+
// Only works if the remainder amount is a loop invaraint
2031+
if (!L->isLoopInvariant(RemAmt))
2032+
return false;
2033+
2034+
// Is the PHI a loop increment?
2035+
auto LoopIncrInfo = getIVIncrement(PN, LI);
2036+
if (!LoopIncrInfo.has_value())
2037+
return false;
2038+
2039+
// We need remainder_amount % increment_amount to be zero. Increment of one
2040+
// satisfies that without any special logic and is overwhelmingly the common
2041+
// case.
2042+
if (!match(LoopIncrInfo->second, m_One()))
2043+
return false;
2044+
2045+
// Need the increment to not overflow.
2046+
if (!match(LoopIncrInfo->first, m_NUWAdd(m_Value(), m_Value())))
2047+
return false;
2048+
2049+
// Set output variables.
2050+
RemAmtOut = RemAmt;
2051+
LoopIncrPNOut = PN;
2052+
AddOrSubOut = AddOrSub;
2053+
AddOrSubOffsetOut = AddOrSubOffset;
2054+
2055+
return true;
2056+
}
2057+
2058+
// Try to transform:
2059+
//
2060+
// for(i = Start; i < End; ++i)
2061+
// Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
2062+
//
2063+
// ->
2064+
//
2065+
// Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant;
2066+
// for(i = Start; i < End; ++i, ++rem)
2067+
// Rem = rem == RemAmtLoopInvariant ? 0 : Rem;
2068+
//
2069+
// Currently only implemented for `Start` and `IncrLoopInvariant` being zero.
2070+
static bool foldURemOfLoopIncrement(Instruction *Rem, const LoopInfo *LI,
2071+
SmallSet<BasicBlock *, 32> &FreshBBs,
2072+
bool IsHuge) {
2073+
std::optional<bool> AddOrSub;
2074+
Value *AddOrSubOffset, *RemAmt;
2075+
PHINode *LoopIncrPN;
2076+
if (!isRemOfLoopIncrementWithLIV(Rem, LI, RemAmt, AddOrSub, AddOrSubOffset,
2077+
LoopIncrPN))
2078+
return false;
2079+
2080+
// Only non-constant remainder as the extra IV is is probably not profitable
2081+
// in that case. Further, since remainder amount is non-constant, only handle
2082+
// case where `IncrLoopInvariant` and `Start` are 0 to entirely eliminate the
2083+
// rem (as opposed to just hoisting it outside of the loop).
2084+
//
2085+
// Potential TODO: Should we have a check for how "nested" this remainder
2086+
// operation is? The new code runs every iteration so if the remainder is
2087+
// guarded behind unlikely conditions this might not be worth it.
2088+
if (AddOrSub.has_value() || match(RemAmt, m_ImmConstant()))
2089+
return false;
2090+
Loop *L = LI->getLoopFor(Rem->getParent());
2091+
if (!match(LoopIncrPN->getIncomingValueForBlock(L->getLoopPreheader()),
2092+
m_Zero()))
2093+
return false;
2094+
2095+
// Create new remainder with induction variable.
2096+
Type *Ty = Rem->getType();
2097+
IRBuilder<> Builder(Rem->getContext());
2098+
2099+
Builder.SetInsertPoint(LoopIncrPN);
2100+
PHINode *NewRem = Builder.CreatePHI(Ty, 2);
2101+
2102+
Builder.SetInsertPoint(cast<Instruction>(
2103+
LoopIncrPN->getIncomingValueForBlock(L->getLoopLatch())));
2104+
// `(add (urem x, y), 1)` is always nuw.
2105+
Value *RemAdd = Builder.CreateNUWAdd(NewRem, ConstantInt::get(Ty, 1));
2106+
Value *RemCmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, RemAdd, RemAmt);
2107+
Value *RemSel =
2108+
Builder.CreateSelect(RemCmp, Constant::getNullValue(Ty), RemAdd);
2109+
2110+
NewRem->addIncoming(Constant::getNullValue(Ty), L->getLoopPreheader());
2111+
NewRem->addIncoming(RemSel, L->getLoopLatch());
2112+
2113+
// Insert all touched BBs.
2114+
FreshBBs.insert(LoopIncrPN->getParent());
2115+
FreshBBs.insert(L->getLoopLatch());
2116+
FreshBBs.insert(Rem->getParent());
2117+
2118+
replaceAllUsesWith(Rem, NewRem, FreshBBs, IsHuge);
2119+
Rem->eraseFromParent();
2120+
return true;
2121+
}
2122+
2123+
bool CodeGenPrepare::optimizeRem(Instruction *Rem) {
2124+
if (foldURemOfLoopIncrement(Rem, LI, FreshBBs, IsHugeFunc))
2125+
return true;
2126+
return false;
2127+
}
2128+
19772129
bool CodeGenPrepare::optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT) {
19782130
if (sinkCmpExpression(Cmp, *TLI))
19792131
return true;
@@ -8360,6 +8512,11 @@ bool CodeGenPrepare::optimizeInst(Instruction *I, ModifyDT &ModifiedDT) {
83608512
if (optimizeCmp(Cmp, ModifiedDT))
83618513
return true;
83628514

8515+
if (match(I, m_URem(m_Value(), m_Value())) ||
8516+
match(I, m_SRem(m_Value(), m_Value())))
8517+
if (optimizeRem(I))
8518+
return true;
8519+
83638520
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
83648521
LI->setMetadata(LLVMContext::MD_invariant_group, nullptr);
83658522
bool Modified = optimizeLoadExt(LI);

0 commit comments

Comments
 (0)