Skip to content

Commit c64ce8b

Browse files
committed
[CodeGenPrepare] Folding urem with loop invariant value
``` for(i = Start; i < End; ++i) Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant; ``` -> ``` Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant; for(i = Start; i < End; ++i, ++rem) Rem = rem == RemAmtLoopInvariant ? 0 : Rem; ``` In its current state, only if `IncrLoopInvariant` and `Start` both being zero. Alive2 seemed unable to prove this (see: https://alive2.llvm.org/ce/z/ATGDp3 which is clearly wrong but still checks out...) so wrote an exhaustive test here: https://godbolt.org/z/WYa561388 Closes #96625
1 parent f16125a commit c64ce8b

File tree

3 files changed

+245
-61
lines changed

3 files changed

+245
-61
lines changed

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ class CodeGenPrepare {
472472
bool replaceMathCmpWithIntrinsic(BinaryOperator *BO, Value *Arg0, Value *Arg1,
473473
CmpInst *Cmp, Intrinsic::ID IID);
474474
bool optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT);
475+
bool optimizeURem(Instruction *Rem);
475476
bool combineToUSubWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
476477
bool combineToUAddWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
477478
void verifyBFIUpdates(Function &F);
@@ -1975,6 +1976,132 @@ static bool foldFCmpToFPClassTest(CmpInst *Cmp, const TargetLowering &TLI,
19751976
return true;
19761977
}
19771978

1979+
static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem,
1980+
const LoopInfo *LI,
1981+
Value *&RemAmtOut,
1982+
PHINode *&LoopIncrPNOut) {
1983+
Value *Incr, *RemAmt;
1984+
// NB: If RemAmt is a power of 2 it *should* have been transformed by now.
1985+
if (!match(Rem, m_URem(m_Value(Incr), m_Value(RemAmt))))
1986+
return false;
1987+
1988+
// Find out loop increment PHI.
1989+
auto *PN = dyn_cast<PHINode>(Incr);
1990+
if (!PN)
1991+
return false;
1992+
1993+
// This isn't strictly necessary, what we really need is one increment and any
1994+
// amount of initial values all being the same.
1995+
if (PN->getNumIncomingValues() != 2)
1996+
return false;
1997+
1998+
// Only trivially analyzable loops.
1999+
Loop *L = LI->getLoopFor(Rem->getParent());
2000+
if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
2001+
return false;
2002+
2003+
// Only works if the remainder amount is a loop invaraint
2004+
if (!L->isLoopInvariant(RemAmt))
2005+
return false;
2006+
2007+
// Is the PHI a loop increment?
2008+
auto LoopIncrInfo = getIVIncrement(PN, LI);
2009+
if (!LoopIncrInfo)
2010+
return false;
2011+
2012+
// getIVIncrement finds the loop at PN->getParent(). This might be a different
2013+
// loop from the loop with Rem->getParent().
2014+
if (L->getHeader() != PN->getParent())
2015+
return false;
2016+
2017+
// We need remainder_amount % increment_amount to be zero. Increment of one
2018+
// satisfies that without any special logic and is overwhelmingly the common
2019+
// case.
2020+
if (!match(LoopIncrInfo->second, m_One()))
2021+
return false;
2022+
2023+
// Need the increment to not overflow.
2024+
if (!match(LoopIncrInfo->first, m_NUWAdd(m_Value(), m_Value())))
2025+
return false;
2026+
2027+
// Set output variables.
2028+
RemAmtOut = RemAmt;
2029+
LoopIncrPNOut = PN;
2030+
2031+
return true;
2032+
}
2033+
2034+
// Try to transform:
2035+
//
2036+
// for(i = Start; i < End; ++i)
2037+
// Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
2038+
//
2039+
// ->
2040+
//
2041+
// Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant;
2042+
// for(i = Start; i < End; ++i, ++rem)
2043+
// Rem = rem == RemAmtLoopInvariant ? 0 : Rem;
2044+
//
2045+
// Currently only implemented for `IncrLoopInvariant` being zero.
2046+
static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL,
2047+
const LoopInfo *LI,
2048+
SmallSet<BasicBlock *, 32> &FreshBBs,
2049+
bool IsHuge) {
2050+
Value *RemAmt;
2051+
PHINode *LoopIncrPN;
2052+
if (!isRemOfLoopIncrementWithLoopInvariant(Rem, LI, RemAmt, LoopIncrPN))
2053+
return false;
2054+
2055+
// Only non-constant remainder as the extra IV is probably not profitable
2056+
// in that case.
2057+
//
2058+
// Potential TODO(1): `urem` of a const ends up as `mul` + `shift` + `add`. If
2059+
// we can rule out register pressure and ensure this `urem` is executed each
2060+
// iteration, its probably profitable to handle the const case as well.
2061+
//
2062+
// Potential TODO(2): Should we have a check for how "nested" this remainder
2063+
// operation is? The new code runs every iteration so if the remainder is
2064+
// guarded behind unlikely conditions this might not be worth it.
2065+
if (match(RemAmt, m_ImmConstant()))
2066+
return false;
2067+
Loop *L = LI->getLoopFor(Rem->getParent());
2068+
2069+
Value *Start = LoopIncrPN->getIncomingValueForBlock(L->getLoopPreheader());
2070+
2071+
// Create new remainder with induction variable.
2072+
Type *Ty = Rem->getType();
2073+
IRBuilder<> Builder(Rem->getContext());
2074+
2075+
Builder.SetInsertPoint(LoopIncrPN);
2076+
PHINode *NewRem = Builder.CreatePHI(Ty, 2);
2077+
2078+
Builder.SetInsertPoint(cast<Instruction>(
2079+
LoopIncrPN->getIncomingValueForBlock(L->getLoopLatch())));
2080+
// `(add (urem x, y), 1)` is always nuw.
2081+
Value *RemAdd = Builder.CreateNUWAdd(NewRem, ConstantInt::get(Ty, 1));
2082+
Value *RemCmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, RemAdd, RemAmt);
2083+
Value *RemSel =
2084+
Builder.CreateSelect(RemCmp, Constant::getNullValue(Ty), RemAdd);
2085+
2086+
NewRem->addIncoming(Start, L->getLoopPreheader());
2087+
NewRem->addIncoming(RemSel, L->getLoopLatch());
2088+
2089+
// Insert all touched BBs.
2090+
FreshBBs.insert(LoopIncrPN->getParent());
2091+
FreshBBs.insert(L->getLoopLatch());
2092+
FreshBBs.insert(Rem->getParent());
2093+
2094+
replaceAllUsesWith(Rem, NewRem, FreshBBs, IsHuge);
2095+
Rem->eraseFromParent();
2096+
return true;
2097+
}
2098+
2099+
bool CodeGenPrepare::optimizeURem(Instruction *Rem) {
2100+
if (foldURemOfLoopIncrement(Rem, DL, LI, FreshBBs, IsHugeFunc))
2101+
return true;
2102+
return false;
2103+
}
2104+
19782105
bool CodeGenPrepare::optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT) {
19792106
if (sinkCmpExpression(Cmp, *TLI))
19802107
return true;
@@ -8358,6 +8485,10 @@ bool CodeGenPrepare::optimizeInst(Instruction *I, ModifyDT &ModifiedDT) {
83588485
if (optimizeCmp(Cmp, ModifiedDT))
83598486
return true;
83608487

8488+
if (match(I, m_URem(m_Value(), m_Value())))
8489+
if (optimizeURem(I))
8490+
return true;
8491+
83618492
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
83628493
LI->setMetadata(LLVMContext::MD_invariant_group, nullptr);
83638494
bool Modified = optimizeLoadExt(LI);

0 commit comments

Comments
 (0)