@@ -471,7 +471,7 @@ class CodeGenPrepare {
471
471
bool replaceMathCmpWithIntrinsic (BinaryOperator *BO, Value *Arg0, Value *Arg1,
472
472
CmpInst *Cmp, Intrinsic::ID IID);
473
473
bool optimizeCmp (CmpInst *Cmp, ModifyDT &ModifiedDT);
474
- bool optimizeRem (Instruction *Rem);
474
+ bool optimizeURem (Instruction *Rem);
475
475
bool combineToUSubWithOverflow (CmpInst *Cmp, ModifyDT &ModifiedDT);
476
476
bool combineToUAddWithOverflow (CmpInst *Cmp, ModifyDT &ModifiedDT);
477
477
void verifyBFIUpdates (Function &F);
@@ -1976,26 +1976,23 @@ static bool foldFCmpToFPClassTest(CmpInst *Cmp, const TargetLowering &TLI,
1976
1976
}
1977
1977
1978
1978
static bool isRemOfLoopIncrementWithLoopInvariant (
1979
- Value *Rem, const LoopInfo *LI, Value *&RemAmtOut,
1980
- std::optional<bool > &AddOrSubOut, Value *&AddOrSubOffsetOut ,
1981
- PHINode *&LoopIncrPNOut) {
1979
+ Instruction *Rem, const LoopInfo *LI, Value *&RemAmtOut,
1980
+ std::optional<bool > &AddOrSubOut, Value *&AddOrSubInstOut ,
1981
+ Value *&AddOrSubOffsetOut, PHINode *&LoopIncrPNOut) {
1982
1982
Value *Incr, *RemAmt;
1983
- if (!isa<Instruction>(Rem))
1984
- return false ;
1985
1983
// NB: If RemAmt is a power of 2 it *should* have been transformed by now.
1986
1984
if (!match (Rem, m_URem (m_Value (Incr), m_Value (RemAmt))))
1987
1985
return false ;
1988
1986
1989
1987
// Only trivially analyzable loops.
1990
- Loop *L = LI->getLoopFor (cast<Instruction>(Rem)->getParent ());
1991
- if (L == nullptr || L->getLoopPreheader () == nullptr ||
1992
- L->getLoopLatch () == nullptr )
1988
+ Loop *L = LI->getLoopFor (Rem->getParent ());
1989
+ if (!L || !L->getLoopPreheader () || !L->getLoopLatch ())
1993
1990
return false ;
1994
1991
1995
1992
std::optional<bool > AddOrSub;
1996
1993
Value *AddOrSubOffset;
1997
1994
// Find out loop increment PHI.
1998
- PHINode *PN = dyn_cast<PHINode>(Incr);
1995
+ auto *PN = dyn_cast<PHINode>(Incr);
1999
1996
if (PN != nullptr ) {
2000
1997
AddOrSub = std::nullopt;
2001
1998
AddOrSubOffset = nullptr ;
@@ -2009,6 +2006,8 @@ static bool isRemOfLoopIncrementWithLoopInvariant(
2009
2006
else
2010
2007
return false ;
2011
2008
2009
+ AddOrSubInstOut = Incr;
2010
+
2012
2011
PN = dyn_cast<PHINode>(V0);
2013
2012
if (PN != nullptr ) {
2014
2013
AddOrSubOffset = V1;
@@ -2018,7 +2017,7 @@ static bool isRemOfLoopIncrementWithLoopInvariant(
2018
2017
}
2019
2018
}
2020
2019
2021
- if (PN == nullptr )
2020
+ if (!PN )
2022
2021
return false ;
2023
2022
2024
2023
// This isn't strictly necessary, what we really need is one increment and any
@@ -2032,7 +2031,12 @@ static bool isRemOfLoopIncrementWithLoopInvariant(
2032
2031
2033
2032
// Is the PHI a loop increment?
2034
2033
auto LoopIncrInfo = getIVIncrement (PN, LI);
2035
- if (!LoopIncrInfo.has_value ())
2034
+ if (!LoopIncrInfo)
2035
+ return false ;
2036
+
2037
+ // getIVIncrement finds the loop at PN->getParent(). This might be a different
2038
+ // loop from the loop with Rem->getParent().
2039
+ if (L->getHeader () != PN->getParent ())
2036
2040
return false ;
2037
2041
2038
2042
// We need remainder_amount % increment_amount to be zero. Increment of one
@@ -2045,11 +2049,6 @@ static bool isRemOfLoopIncrementWithLoopInvariant(
2045
2049
if (!match (LoopIncrInfo->first , m_NUWAdd (m_Value (), m_Value ())))
2046
2050
return false ;
2047
2051
2048
- // Need unique loop preheader and latch.
2049
- if (PN->getBasicBlockIndex (L->getLoopLatch ()) < 0 ||
2050
- PN->getBasicBlockIndex (L->getLoopPreheader ()) < 0 )
2051
- return false ;
2052
-
2053
2052
// Set output variables.
2054
2053
RemAmtOut = RemAmt;
2055
2054
LoopIncrPNOut = PN;
@@ -2071,20 +2070,19 @@ static bool isRemOfLoopIncrementWithLoopInvariant(
2071
2070
// Rem = rem == RemAmtLoopInvariant ? 0 : Rem;
2072
2071
//
2073
2072
// Currently only implemented for `Start` and `IncrLoopInvariant` being zero.
2074
- static bool foldURemOfLoopIncrement (Instruction *Rem, const LoopInfo *LI,
2073
+ static bool foldURemOfLoopIncrement (Instruction *Rem, const DataLayout *DL,
2074
+ const LoopInfo *LI,
2075
2075
SmallSet<BasicBlock *, 32 > &FreshBBs,
2076
2076
bool IsHuge) {
2077
2077
std::optional<bool > AddOrSub;
2078
- Value *AddOrSubOffset, *RemAmt;
2078
+ Value *AddOrSubOffset, *RemAmt, *AddOrSubInst ;
2079
2079
PHINode *LoopIncrPN;
2080
- if (!isRemOfLoopIncrementWithLoopInvariant (Rem, LI, RemAmt, AddOrSub,
2081
- AddOrSubOffset, LoopIncrPN))
2080
+ if (!isRemOfLoopIncrementWithLoopInvariant (
2081
+ Rem, LI, RemAmt, AddOrSub, AddOrSubInst, AddOrSubOffset, LoopIncrPN))
2082
2082
return false ;
2083
2083
2084
2084
// Only non-constant remainder as the extra IV is probably not profitable
2085
- // in that case. Further, since remainder amount is non-constant, only handle
2086
- // case where `IncrLoopInvariant` and `Start` are 0 to entirely eliminate the
2087
- // rem (as opposed to just hoisting it outside of the loop).
2085
+ // in that case.
2088
2086
//
2089
2087
// Potential TODO(1): `urem` of a const ends up as `mul` + `shift` + `add`. If
2090
2088
// we can rule out register pressure and ensure this `urem` is executed each
@@ -2093,12 +2091,37 @@ static bool foldURemOfLoopIncrement(Instruction *Rem, const LoopInfo *LI,
2093
2091
// Potential TODO(2): Should we have a check for how "nested" this remainder
2094
2092
// operation is? The new code runs every iteration so if the remainder is
2095
2093
// guarded behind unlikely conditions this might not be worth it.
2096
- if (AddOrSub. has_value () || match (RemAmt, m_ImmConstant ()))
2094
+ if (match (RemAmt, m_ImmConstant ()))
2097
2095
return false ;
2098
2096
Loop *L = LI->getLoopFor (Rem->getParent ());
2099
- if (!match (LoopIncrPN->getIncomingValueForBlock (L->getLoopPreheader ()),
2100
- m_Zero ()))
2101
- return false ;
2097
+
2098
+ // If we have add/sub create initial value for remainder.
2099
+ // The logic here is:
2100
+ // (urem (add/sub nuw Start, IncrLoopInvariant), RemAmtLoopInvariant
2101
+ //
2102
+ // Only proceed if the expression simplifies (otherwise we can't fully
2103
+ // optimize out the urem).
2104
+ Value *Start = LoopIncrPN->getIncomingValueForBlock (L->getLoopPreheader ());
2105
+ if (AddOrSub) {
2106
+ assert (AddOrSubOffset && AddOrSubInst &&
2107
+ " We found an add/sub but missing values" );
2108
+ // Without dom-condition/assumption cache we aren't likely to get much out
2109
+ // of a context instruction.
2110
+ const SimplifyQuery Q (*DL);
2111
+ bool NSW = cast<OverflowingBinaryOperator>(AddOrSubInst)->hasNoSignedWrap ();
2112
+ if (*AddOrSub)
2113
+ Start = simplifyAddInst (Start, AddOrSubOffset, /* IsNSW=*/ NSW,
2114
+ /* IsNUW=*/ true , Q);
2115
+ else
2116
+ Start = simplifySubInst (Start, AddOrSubOffset, /* IsNSW=*/ NSW,
2117
+ /* IsNUW=*/ true , Q);
2118
+ if (!Start)
2119
+ return false ;
2120
+
2121
+ Start = simplifyURemInst (Start, RemAmt, Q);
2122
+ if (!Start)
2123
+ return false ;
2124
+ }
2102
2125
2103
2126
// Create new remainder with induction variable.
2104
2127
Type *Ty = Rem->getType ();
@@ -2115,7 +2138,7 @@ static bool foldURemOfLoopIncrement(Instruction *Rem, const LoopInfo *LI,
2115
2138
Value *RemSel =
2116
2139
Builder.CreateSelect (RemCmp, Constant::getNullValue (Ty), RemAdd);
2117
2140
2118
- NewRem->addIncoming (Constant::getNullValue (Ty) , L->getLoopPreheader ());
2141
+ NewRem->addIncoming (Start , L->getLoopPreheader ());
2119
2142
NewRem->addIncoming (RemSel, L->getLoopLatch ());
2120
2143
2121
2144
// Insert all touched BBs.
@@ -2125,11 +2148,13 @@ static bool foldURemOfLoopIncrement(Instruction *Rem, const LoopInfo *LI,
2125
2148
2126
2149
replaceAllUsesWith (Rem, NewRem, FreshBBs, IsHuge);
2127
2150
Rem->eraseFromParent ();
2151
+ if (AddOrSubInst && AddOrSubInst->use_empty ())
2152
+ cast<Instruction>(AddOrSubInst)->eraseFromParent ();
2128
2153
return true ;
2129
2154
}
2130
2155
2131
- bool CodeGenPrepare::optimizeRem (Instruction *Rem) {
2132
- if (foldURemOfLoopIncrement (Rem, LI, FreshBBs, IsHugeFunc))
2156
+ bool CodeGenPrepare::optimizeURem (Instruction *Rem) {
2157
+ if (foldURemOfLoopIncrement (Rem, DL, LI, FreshBBs, IsHugeFunc))
2133
2158
return true ;
2134
2159
return false ;
2135
2160
}
@@ -8520,9 +8545,8 @@ bool CodeGenPrepare::optimizeInst(Instruction *I, ModifyDT &ModifiedDT) {
8520
8545
if (optimizeCmp (Cmp, ModifiedDT))
8521
8546
return true ;
8522
8547
8523
- if (match (I, m_URem (m_Value (), m_Value ())) ||
8524
- match (I, m_SRem (m_Value (), m_Value ())))
8525
- if (optimizeRem (I))
8548
+ if (match (I, m_URem (m_Value (), m_Value ())))
8549
+ if (optimizeURem (I))
8526
8550
return true ;
8527
8551
8528
8552
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
0 commit comments