Skip to content

Commit e4c67ba

Browse files
committed
Recommit "[CodeGenPrepare] Folding urem with loop invariant value"
Was missing remainder on `Start` value. Also changed logic as nikic suggested (getting loop from `PN` instead of `Rem`). The prior impl increased the complexity of the code and made debugging it more difficult. Closes #104877
1 parent 9b25ad8 commit e4c67ba

File tree

3 files changed

+267
-75
lines changed

3 files changed

+267
-75
lines changed

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ class CodeGenPrepare {
472472
bool replaceMathCmpWithIntrinsic(BinaryOperator *BO, Value *Arg0, Value *Arg1,
473473
CmpInst *Cmp, Intrinsic::ID IID);
474474
bool optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT);
475+
bool optimizeURem(Instruction *Rem);
475476
bool combineToUSubWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
476477
bool combineToUAddWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
477478
void verifyBFIUpdates(Function &F);
@@ -1975,6 +1976,135 @@ static bool foldFCmpToFPClassTest(CmpInst *Cmp, const TargetLowering &TLI,
19751976
return true;
19761977
}
19771978

1979+
/// Check whether \p Rem has the shape `urem PN, Amt`, where `PN` is a
/// two-input PHI that `getIVIncrement` recognizes as a loop increment with a
/// step of one and a `nuw` add, and where `Amt` is invariant in the loop that
/// owns `PN`.
///
/// \param Rem            Candidate remainder instruction.
/// \param LI             Loop information used to look up the loop of the PHI.
/// \param RemAmtOut      Receives the divisor of \p Rem on success.
/// \param LoopIncrPNOut  Receives the PHI feeding \p Rem on success.
/// \returns true when every condition holds. On failure the two output
///          parameters are left unmodified.
static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem,
                                                  const LoopInfo *LI,
                                                  Value *&RemAmtOut,
                                                  PHINode *&LoopIncrPNOut) {
  // NB: a power-of-2 divisor *should* already have been rewritten by the time
  // we get here.
  Value *Dividend = nullptr;
  Value *Divisor = nullptr;
  if (!match(Rem, m_URem(m_Value(Dividend), m_Value(Divisor))))
    return false;

  // The dividend has to be the PHI itself. Requiring exactly two incoming
  // values is stricter than needed: one increment plus any number of identical
  // initial values would be enough.
  auto *IndVar = dyn_cast<PHINode>(Dividend);
  if (!IndVar || IndVar->getNumIncomingValues() != 2)
    return false;

  // Restrict ourselves to loops that are trivial to analyze: there must be a
  // loop, and it must have both a preheader and a latch.
  Loop *TheLoop = LI->getLoopFor(IndVar->getParent());
  const bool IsSimpleLoop =
      TheLoop && TheLoop->getLoopPreheader() && TheLoop->getLoopLatch();
  if (!IsSimpleLoop)
    return false;

  // The remainder must live inside the loop, and its divisor must be a loop
  // invariant.
  if (!TheLoop->contains(Rem) || !TheLoop->isLoopInvariant(Divisor))
    return false;

  // The PHI must be a loop increment.
  auto IncrInfo = getIVIncrement(IndVar, LI);
  if (!IncrInfo)
    return false;

  // What is really required is remainder_amount % increment_amount == 0. A
  // step of one gives that for free and is by far the most common case. On top
  // of that, the increment must be an add that cannot wrap (nuw).
  if (!(match(IncrInfo->second, m_One()) &&
        match(IncrInfo->first, m_c_NUWAdd(m_Specific(IndVar), m_Value()))))
    return false;

  // Everything checked out; publish the results.
  RemAmtOut = Divisor;
  LoopIncrPNOut = IndVar;
  return true;
}
2032+
2033+
// Try to transform:
2034+
//
2035+
// for(i = Start; i < End; ++i)
2036+
// Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
2037+
//
2038+
// ->
2039+
//
2040+
// Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant;
2041+
// for(i = Start; i < End; ++i, ++rem)
2042+
// Rem = rem == RemAmtLoopInvariant ? 0 : Rem;
2043+
//
2044+
// Currently only implemented for `IncrLoopInvariant` being zero.
2045+
static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL,
2046+
const LoopInfo *LI,
2047+
SmallSet<BasicBlock *, 32> &FreshBBs,
2048+
bool IsHuge) {
2049+
Value *RemAmt;
2050+
PHINode *LoopIncrPN;
2051+
if (!isRemOfLoopIncrementWithLoopInvariant(Rem, LI, RemAmt, LoopIncrPN))
2052+
return false;
2053+
2054+
// Only non-constant remainder as the extra IV is probably not profitable
2055+
// in that case.
2056+
//
2057+
// Potential TODO(1): `urem` of a const ends up as `mul` + `shift` + `add`. If
2058+
// we can rule out register pressure and ensure this `urem` is executed each
2059+
// iteration, its probably profitable to handle the const case as well.
2060+
//
2061+
// Potential TODO(2): Should we have a check for how "nested" this remainder
2062+
// operation is? The new code runs every iteration so if the remainder is
2063+
// guarded behind unlikely conditions this might not be worth it.
2064+
if (match(RemAmt, m_ImmConstant()))
2065+
return false;
2066+
2067+
Loop *L = LI->getLoopFor(LoopIncrPN->getParent());
2068+
Value *Start = LoopIncrPN->getIncomingValueForBlock(L->getLoopPreheader());
2069+
// If we can't fully optimize out the `rem`, skip this transform.
2070+
Start = simplifyURemInst(Start, RemAmt, *DL);
2071+
if (!Start)
2072+
return false;
2073+
2074+
// Create new remainder with induction variable.
2075+
Type *Ty = Rem->getType();
2076+
IRBuilder<> Builder(Rem->getContext());
2077+
2078+
Builder.SetInsertPoint(LoopIncrPN);
2079+
PHINode *NewRem = Builder.CreatePHI(Ty, 2);
2080+
2081+
Builder.SetInsertPoint(cast<Instruction>(
2082+
LoopIncrPN->getIncomingValueForBlock(L->getLoopLatch())));
2083+
// `(add (urem x, y), 1)` is always nuw.
2084+
Value *RemAdd = Builder.CreateNUWAdd(NewRem, ConstantInt::get(Ty, 1));
2085+
Value *RemCmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, RemAdd, RemAmt);
2086+
Value *RemSel =
2087+
Builder.CreateSelect(RemCmp, Constant::getNullValue(Ty), RemAdd);
2088+
2089+
NewRem->addIncoming(Start, L->getLoopPreheader());
2090+
NewRem->addIncoming(RemSel, L->getLoopLatch());
2091+
2092+
// Insert all touched BBs.
2093+
FreshBBs.insert(LoopIncrPN->getParent());
2094+
FreshBBs.insert(L->getLoopLatch());
2095+
FreshBBs.insert(Rem->getParent());
2096+
2097+
replaceAllUsesWith(Rem, NewRem, FreshBBs, IsHuge);
2098+
Rem->eraseFromParent();
2099+
return true;
2100+
}
2101+
2102+
/// Entry point for `urem` optimizations. At present the only one attempted is
/// folding a remainder of a loop increment by a loop-invariant amount.
///
/// \param Rem The remainder instruction to optimize.
/// \returns true if \p Rem was rewritten, false otherwise.
bool CodeGenPrepare::optimizeURem(Instruction *Rem) {
  return foldURemOfLoopIncrement(Rem, DL, LI, FreshBBs, IsHugeFunc);
}
2107+
19782108
bool CodeGenPrepare::optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT) {
19792109
if (sinkCmpExpression(Cmp, *TLI))
19802110
return true;
@@ -8358,6 +8488,10 @@ bool CodeGenPrepare::optimizeInst(Instruction *I, ModifyDT &ModifiedDT) {
83588488
if (optimizeCmp(Cmp, ModifiedDT))
83598489
return true;
83608490

8491+
if (match(I, m_URem(m_Value(), m_Value())))
8492+
if (optimizeURem(I))
8493+
return true;
8494+
83618495
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
83628496
LI->setMetadata(LLVMContext::MD_invariant_group, nullptr);
83638497
bool Modified = optimizeLoadExt(LI);

0 commit comments

Comments
 (0)