Skip to content

Commit 446e104

Browse files
committed
[CodeGenPrepare] Folding urem with loop invariant value
``` for(i = Start; i < End; ++i) Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant; ``` -> ``` Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant; for(i = Start; i < End; ++i, ++rem) Rem = rem == RemAmtLoopInvariant ? 0 : Rem; ``` In its current state, only if `IncrLoopInvariant` and `Start` both being zero. Alive2 seemed unable to prove this (see: https://alive2.llvm.org/ce/z/ATGDp3 which is clearly wrong but still checks out...) so wrote an exhaustive test here: https://godbolt.org/z/WYa561388
1 parent e2e2fbf commit 446e104

File tree

2 files changed

+223
-43
lines changed

2 files changed

+223
-43
lines changed

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ class CodeGenPrepare {
471471
bool replaceMathCmpWithIntrinsic(BinaryOperator *BO, Value *Arg0, Value *Arg1,
472472
CmpInst *Cmp, Intrinsic::ID IID);
473473
bool optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT);
474+
bool optimizeRem(Instruction *Rem);
474475
bool combineToUSubWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
475476
bool combineToUAddWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
476477
void verifyBFIUpdates(Function &F);
@@ -1974,6 +1975,160 @@ static bool foldFCmpToFPClassTest(CmpInst *Cmp, const TargetLowering &TLI,
19741975
return true;
19751976
}
19761977

1978+
static bool isRemOfLoopIncrementWithLoopInvariant(
1979+
Value *Rem, const LoopInfo *LI, Value *&RemAmtOut,
1980+
std::optional<bool> &AddOrSubOut, Value *&AddOrSubOffsetOut,
1981+
PHINode *&LoopIncrPNOut) {
1982+
Value *Incr, *RemAmt;
1983+
if (!isa<Instruction>(Rem))
1984+
return false;
1985+
// NB: If RemAmt is a power of 2 it *should* have been transformed by now.
1986+
if (!match(Rem, m_URem(m_Value(Incr), m_Value(RemAmt))))
1987+
return false;
1988+
1989+
// Only trivially analyzable loops.
1990+
Loop *L = LI->getLoopFor(cast<Instruction>(Rem)->getParent());
1991+
if (L == nullptr || L->getLoopPreheader() == nullptr ||
1992+
L->getLoopLatch() == nullptr)
1993+
return false;
1994+
1995+
std::optional<bool> AddOrSub;
1996+
Value *AddOrSubOffset;
1997+
// Find out loop increment PHI.
1998+
PHINode *PN = dyn_cast<PHINode>(Incr);
1999+
if (PN != nullptr) {
2000+
AddOrSub = std::nullopt;
2001+
AddOrSubOffset = nullptr;
2002+
} else {
2003+
// Search through a NUW add/sub.
2004+
Value *V0, *V1;
2005+
if (match(Incr, m_NUWAddLike(m_Value(V0), m_Value(V1))))
2006+
AddOrSub = true;
2007+
else if (match(Incr, m_NUWSub(m_Value(V0), m_Value(V1))))
2008+
AddOrSub = false;
2009+
else
2010+
return false;
2011+
2012+
PN = dyn_cast<PHINode>(V0);
2013+
if (PN != nullptr) {
2014+
AddOrSubOffset = V1;
2015+
} else if (*AddOrSub) {
2016+
PN = dyn_cast<PHINode>(V1);
2017+
AddOrSubOffset = V0;
2018+
}
2019+
}
2020+
2021+
if (PN == nullptr)
2022+
return false;
2023+
2024+
// This isn't strictly necessary, what we really need is one increment and any
2025+
// amount of initial values all being the same.
2026+
if (PN->getNumIncomingValues() != 2)
2027+
return false;
2028+
2029+
// Only works if the remainder amount is a loop invaraint
2030+
if (!L->isLoopInvariant(RemAmt))
2031+
return false;
2032+
2033+
// Is the PHI a loop increment?
2034+
auto LoopIncrInfo = getIVIncrement(PN, LI);
2035+
if (!LoopIncrInfo.has_value())
2036+
return false;
2037+
2038+
// We need remainder_amount % increment_amount to be zero. Increment of one
2039+
// satisfies that without any special logic and is overwhelmingly the common
2040+
// case.
2041+
if (!match(LoopIncrInfo->second, m_One()))
2042+
return false;
2043+
2044+
// Need the increment to not overflow.
2045+
if (!match(LoopIncrInfo->first, m_NUWAdd(m_Value(), m_Value())))
2046+
return false;
2047+
2048+
if (PN->getBasicBlockIndex(L->getLoopLatch()) < 0 ||
2049+
PN->getBasicBlockIndex(L->getLoopPreheader()) < 0)
2050+
return false;
2051+
2052+
// Set output variables.
2053+
RemAmtOut = RemAmt;
2054+
LoopIncrPNOut = PN;
2055+
AddOrSubOut = AddOrSub;
2056+
AddOrSubOffsetOut = AddOrSubOffset;
2057+
2058+
return true;
2059+
}
2060+
2061+
// Try to transform:
2062+
//
2063+
// for(i = Start; i < End; ++i)
2064+
// Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
2065+
//
2066+
// ->
2067+
//
2068+
// Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant;
2069+
// for(i = Start; i < End; ++i, ++rem)
2070+
// Rem = rem == RemAmtLoopInvariant ? 0 : Rem;
2071+
//
2072+
// Currently only implemented for `Start` and `IncrLoopInvariant` being zero.
2073+
static bool foldURemOfLoopIncrement(Instruction *Rem, const LoopInfo *LI,
2074+
SmallSet<BasicBlock *, 32> &FreshBBs,
2075+
bool IsHuge) {
2076+
std::optional<bool> AddOrSub;
2077+
Value *AddOrSubOffset, *RemAmt;
2078+
PHINode *LoopIncrPN;
2079+
if (!isRemOfLoopIncrementWithLoopInvariant(Rem, LI, RemAmt, AddOrSub,
2080+
AddOrSubOffset, LoopIncrPN))
2081+
return false;
2082+
2083+
// Only non-constant remainder as the extra IV is is probably not profitable
2084+
// in that case. Further, since remainder amount is non-constant, only handle
2085+
// case where `IncrLoopInvariant` and `Start` are 0 to entirely eliminate the
2086+
// rem (as opposed to just hoisting it outside of the loop).
2087+
//
2088+
// Potential TODO: Should we have a check for how "nested" this remainder
2089+
// operation is? The new code runs every iteration so if the remainder is
2090+
// guarded behind unlikely conditions this might not be worth it.
2091+
if (AddOrSub.has_value() || match(RemAmt, m_ImmConstant()))
2092+
return false;
2093+
Loop *L = LI->getLoopFor(Rem->getParent());
2094+
if (!match(LoopIncrPN->getIncomingValueForBlock(L->getLoopPreheader()),
2095+
m_Zero()))
2096+
return false;
2097+
2098+
// Create new remainder with induction variable.
2099+
Type *Ty = Rem->getType();
2100+
IRBuilder<> Builder(Rem->getContext());
2101+
2102+
Builder.SetInsertPoint(LoopIncrPN);
2103+
PHINode *NewRem = Builder.CreatePHI(Ty, 2);
2104+
2105+
Builder.SetInsertPoint(cast<Instruction>(
2106+
LoopIncrPN->getIncomingValueForBlock(L->getLoopLatch())));
2107+
// `(add (urem x, y), 1)` is always nuw.
2108+
Value *RemAdd = Builder.CreateNUWAdd(NewRem, ConstantInt::get(Ty, 1));
2109+
Value *RemCmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, RemAdd, RemAmt);
2110+
Value *RemSel =
2111+
Builder.CreateSelect(RemCmp, Constant::getNullValue(Ty), RemAdd);
2112+
2113+
NewRem->addIncoming(Constant::getNullValue(Ty), L->getLoopPreheader());
2114+
NewRem->addIncoming(RemSel, L->getLoopLatch());
2115+
2116+
// Insert all touched BBs.
2117+
FreshBBs.insert(LoopIncrPN->getParent());
2118+
FreshBBs.insert(L->getLoopLatch());
2119+
FreshBBs.insert(Rem->getParent());
2120+
2121+
replaceAllUsesWith(Rem, NewRem, FreshBBs, IsHuge);
2122+
Rem->eraseFromParent();
2123+
return true;
2124+
}
2125+
2126+
bool CodeGenPrepare::optimizeRem(Instruction *Rem) {
2127+
if (foldURemOfLoopIncrement(Rem, LI, FreshBBs, IsHugeFunc))
2128+
return true;
2129+
return false;
2130+
}
2131+
19772132
bool CodeGenPrepare::optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT) {
19782133
if (sinkCmpExpression(Cmp, *TLI))
19792134
return true;
@@ -8360,6 +8515,11 @@ bool CodeGenPrepare::optimizeInst(Instruction *I, ModifyDT &ModifiedDT) {
83608515
if (optimizeCmp(Cmp, ModifiedDT))
83618516
return true;
83628517

8518+
if (match(I, m_URem(m_Value(), m_Value())) ||
8519+
match(I, m_SRem(m_Value(), m_Value())))
8520+
if (optimizeRem(I))
8521+
return true;
8522+
83638523
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
83648524
LI->setMetadata(LLVMContext::MD_invariant_group, nullptr);
83658525
bool Modified = optimizeLoadExt(LI);

0 commit comments

Comments
 (0)