Skip to content

[CodeGenPrepare] Folding urem with loop invariant value as remainder #96625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions llvm/lib/CodeGen/CodeGenPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ class CodeGenPrepare {
bool replaceMathCmpWithIntrinsic(BinaryOperator *BO, Value *Arg0, Value *Arg1,
CmpInst *Cmp, Intrinsic::ID IID);
bool optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT);
bool optimizeURem(Instruction *Rem);
bool combineToUSubWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
bool combineToUAddWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
void verifyBFIUpdates(Function &F);
Expand Down Expand Up @@ -1974,6 +1975,133 @@ static bool foldFCmpToFPClassTest(CmpInst *Cmp, const TargetLowering &TLI,
return true;
}

// Recognize `Rem` as `urem(PN, RemAmt)` where `PN` is a simple loop counter and
// `RemAmt` does not change inside the loop.
//
// Returns true only when all of the following hold:
//  - `Rem` matches `m_URem(Incr, RemAmt)` and `Incr` is a PHINode with exactly
//    two incoming values;
//  - the loop containing `Rem` exists, has a preheader and a latch, and its
//    header is the block that holds the PHI;
//  - `RemAmt` is invariant in that loop;
//  - `getIVIncrement` recognizes the PHI, the second element of its result
//    matches `m_One()`, and the first element is an `add nuw`.
//
// On success `RemAmtOut` receives the remainder amount and `LoopIncrPNOut` the
// PHI. On failure neither output is written.
static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem,
const LoopInfo *LI,
Value *&RemAmtOut,
PHINode *&LoopIncrPNOut) {
Value *Incr, *RemAmt;
// NB: If RemAmt is a power of 2 it *should* have been transformed by now.
if (!match(Rem, m_URem(m_Value(Incr), m_Value(RemAmt))))
return false;

// Only trivially analyzable loops.
Loop *L = LI->getLoopFor(Rem->getParent());
if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
Copy link
Contributor

@nikic nikic Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC getLoopPreheader is expensive, so maybe move this after the phi check.

return false;

// Find out loop increment PHI.
auto *PN = dyn_cast<PHINode>(Incr);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Drop this newline.

if (!PN)
return false;

// This isn't strictly necessary, what we really need is one increment and any
// amount of initial values all being the same.
if (PN->getNumIncomingValues() != 2)
return false;

// Only works if the remainder amount is a loop invariant
if (!L->isLoopInvariant(RemAmt))
return false;

// Is the PHI a loop increment?
auto LoopIncrInfo = getIVIncrement(PN, LI);
if (!LoopIncrInfo)
return false;

// getIVIncrement finds the loop at PN->getParent(). This might be a different
// loop from the loop with Rem->getParent().
if (L->getHeader() != PN->getParent())
Copy link
Contributor

@nikic nikic Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think technically the right thing to do would be to have L be the loop of the phi node, not the rem, and only check that L also contains rem (which is not guaranteed as we don't require LCSSA in CGP). The loop of the phi node is really what you care about, the urem itself can be in a nested loop without issue (I think?)

But I agree it's best to ignore this case for now.

return false;

// We need remainder_amount % increment_amount to be zero. Increment of one
// satisfies that without any special logic and is overwhelmingly the common
// case.
if (!match(LoopIncrInfo->second, m_One()))
return false;

// Need the increment to not overflow.
if (!match(LoopIncrInfo->first, m_NUWAdd(m_Value(), m_Value())))
return false;

// Set output variables.
RemAmtOut = RemAmt;
LoopIncrPNOut = PN;

return true;
}

// Try to transform:
//
// for(i = Start; i < End; ++i)
//    Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
//
// ->
//
// Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant;
// for(i = Start; i < End; ++i, ++rem)
//    Rem = rem == RemAmtLoopInvariant ? 0 : Rem;
//
// Currently only implemented for `IncrLoopInvariant` being zero and `Start`
// being zero. With a zero `Start` the initial remainder is known to be zero,
// so `Start` itself can seed the new induction variable and the `urem` is
// removed entirely.
//
// Parameters:
//   Rem      - candidate `urem` instruction; erased on success.
//   DL       - data layout (currently unused by this fold).
//   LI       - loop info used to locate the loop around `Rem`.
//   FreshBBs - receives every basic block this fold touches.
//   IsHuge   - forwarded to replaceAllUsesWith.
//
// Returns true if `Rem` was replaced and erased, false if nothing changed.
static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL,
                                    const LoopInfo *LI,
                                    SmallSet<BasicBlock *, 32> &FreshBBs,
                                    bool IsHuge) {
  Value *RemAmt;
  PHINode *LoopIncrPN;
  if (!isRemOfLoopIncrementWithLoopInvariant(Rem, LI, RemAmt, LoopIncrPN))
    return false;

  // Only non-constant remainder as the extra IV is probably not profitable
  // in that case.
  //
  // Potential TODO(1): `urem` of a const ends up as `mul` + `shift` + `add`. If
  // we can rule out register pressure and ensure this `urem` is executed each
  // iteration, it's probably profitable to handle the const case as well.
  //
  // Potential TODO(2): Should we have a check for how "nested" this remainder
  // operation is? The new code runs every iteration so if the remainder is
  // guarded behind unlikely conditions this might not be worth it.
  if (match(RemAmt, m_ImmConstant()))
    return false;

  // The helper above already verified that this loop exists and has both a
  // preheader and a latch.
  Loop *L = LI->getLoopFor(Rem->getParent());
  BasicBlock *Preheader = L->getLoopPreheader();
  BasicBlock *Latch = L->getLoopLatch();

  Value *Start = LoopIncrPN->getIncomingValueForBlock(Preheader);

  // The new induction variable must start at `Start u% RemAmt`, not at
  // `Start`: seeding it with a `Start >= RemAmt` would produce values that are
  // never wrapped back into [0, RemAmt). Only a zero `Start` is its own
  // remainder, so restrict the fold to that case.
  //
  // TODO: Support non-zero `Start` when `Start u% RemAmt` can be simplified
  // without materializing a new `urem`. Emitting the `urem` in the preheader
  // would keep a division around and could execute it on paths where the
  // original was never reached.
  if (!match(Start, m_Zero()))
    return false;

  // Create new remainder with induction variable.
  Type *Ty = Rem->getType();
  IRBuilder<> Builder(Rem->getContext());

  Builder.SetInsertPoint(LoopIncrPN);
  PHINode *NewRem = Builder.CreatePHI(Ty, 2);

  // Emit the remainder update right before the original IV increment.
  Builder.SetInsertPoint(
      cast<Instruction>(LoopIncrPN->getIncomingValueForBlock(Latch)));
  // `(add (urem x, y), 1)` is always nuw.
  Value *RemAdd = Builder.CreateNUWAdd(NewRem, ConstantInt::get(Ty, 1));
  Value *RemCmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, RemAdd, RemAmt);
  Value *RemSel =
      Builder.CreateSelect(RemCmp, Constant::getNullValue(Ty), RemAdd);

  NewRem->addIncoming(Start, Preheader);
  NewRem->addIncoming(RemSel, Latch);

  // Insert all touched BBs.
  FreshBBs.insert(LoopIncrPN->getParent());
  FreshBBs.insert(Latch);
  FreshBBs.insert(Rem->getParent());

  replaceAllUsesWith(Rem, NewRem, FreshBBs, IsHuge);
  Rem->eraseFromParent();
  return true;
}

// Entry point for `urem` optimizations in CodeGenPrepare. Reports whether the
// loop-increment fold rewrote (and erased) `Rem`.
bool CodeGenPrepare::optimizeURem(Instruction *Rem) {
  return foldURemOfLoopIncrement(Rem, DL, LI, FreshBBs, IsHugeFunc);
}

bool CodeGenPrepare::optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT) {
if (sinkCmpExpression(Cmp, *TLI))
return true;
Expand Down Expand Up @@ -8360,6 +8488,10 @@ bool CodeGenPrepare::optimizeInst(Instruction *I, ModifyDT &ModifiedDT) {
if (optimizeCmp(Cmp, ModifiedDT))
return true;

if (match(I, m_URem(m_Value(), m_Value())))
if (optimizeURem(I))
return true;

if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
LI->setMetadata(LLVMContext::MD_invariant_group, nullptr);
bool Modified = optimizeLoadExt(LI);
Expand Down
Loading
Loading