@@ -2283,6 +2283,8 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
2283
2283
ICmpInst::Predicate PredL = LHS->getPredicate (), PredR = RHS->getPredicate ();
2284
2284
Value *LHS0 = LHS->getOperand (0 ), *RHS0 = RHS->getOperand (0 );
2285
2285
Value *LHS1 = LHS->getOperand (1 ), *RHS1 = RHS->getOperand (1 );
2286
+ auto *LHSC = dyn_cast<ConstantInt>(LHS1);
2287
+ auto *RHSC = dyn_cast<ConstantInt>(RHS1);
2286
2288
2287
2289
// Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3)
2288
2290
// --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3)
@@ -2294,43 +2296,42 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
2294
2296
// 3) C1 ^ C2 is one-bit mask.
2295
2297
// 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask.
2296
2298
// This implies all values in the two ranges differ by exactly one bit.
2297
- const APInt *LHSVal, *RHSVal;
2298
2299
if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) &&
2299
- PredL == PredR && LHS->getType () == RHS->getType () &&
2300
- LHS->getType ()->isIntOrIntVectorTy () && match (LHS1, m_APInt (LHSVal)) &&
2301
- match (RHS1, m_APInt (RHSVal)) && *LHSVal == *RHSVal && LHS->hasOneUse () &&
2302
- RHS->hasOneUse ()) {
2303
- Value *AddOpnd;
2304
- const APInt *LAddVal, *RAddVal;
2305
- if (match (LHS0, m_Add (m_Value (AddOpnd), m_APInt (LAddVal))) &&
2306
- match (RHS0, m_Add (m_Specific (AddOpnd), m_APInt (RAddVal))) &&
2307
- LAddVal->ugt (*LHSVal) && RAddVal->ugt (*LHSVal)) {
2308
-
2309
- APInt DiffC = *LAddVal ^ *RAddVal;
2310
- if (DiffC.isPowerOf2 ()) {
2311
- const APInt *MaxAddC = nullptr ;
2312
- if (LAddVal->ult (*RAddVal))
2313
- MaxAddC = RAddVal;
2300
+ PredL == PredR && LHSC && RHSC && LHS->hasOneUse () && RHS->hasOneUse () &&
2301
+ LHSC->getType () == RHSC->getType () &&
2302
+ LHSC->getValue () == (RHSC->getValue ())) {
2303
+
2304
+ Value *LAddOpnd, *RAddOpnd;
2305
+ ConstantInt *LAddC, *RAddC;
2306
+ if (match (LHS0, m_Add (m_Value (LAddOpnd), m_ConstantInt (LAddC))) &&
2307
+ match (RHS0, m_Add (m_Value (RAddOpnd), m_ConstantInt (RAddC))) &&
2308
+ LAddC->getValue ().ugt (LHSC->getValue ()) &&
2309
+ RAddC->getValue ().ugt (LHSC->getValue ())) {
2310
+
2311
+ APInt DiffC = LAddC->getValue () ^ RAddC->getValue ();
2312
+ if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2 ()) {
2313
+ ConstantInt *MaxAddC = nullptr ;
2314
+ if (LAddC->getValue ().ult (RAddC->getValue ()))
2315
+ MaxAddC = RAddC;
2314
2316
else
2315
- MaxAddC = LAddVal ;
2317
+ MaxAddC = LAddC ;
2316
2318
2317
- APInt RRangeLow = -*RAddVal ;
2318
- APInt RRangeHigh = RRangeLow + *LHSVal ;
2319
- APInt LRangeLow = -*LAddVal ;
2320
- APInt LRangeHigh = LRangeLow + *LHSVal ;
2319
+ APInt RRangeLow = -RAddC-> getValue () ;
2320
+ APInt RRangeHigh = RRangeLow + LHSC-> getValue () ;
2321
+ APInt LRangeLow = -LAddC-> getValue () ;
2322
+ APInt LRangeHigh = LRangeLow + LHSC-> getValue () ;
2321
2323
APInt LowRangeDiff = RRangeLow ^ LRangeLow;
2322
2324
APInt HighRangeDiff = RRangeHigh ^ LRangeHigh;
2323
2325
APInt RangeDiff = LRangeLow.sgt (RRangeLow) ? LRangeLow - RRangeLow
2324
2326
: RRangeLow - LRangeLow;
2325
2327
2326
2328
if (LowRangeDiff.isPowerOf2 () && LowRangeDiff == HighRangeDiff &&
2327
- RangeDiff.ugt (*LHSVal)) {
2328
- Value *NewAnd = Builder.CreateAnd (
2329
- AddOpnd, ConstantInt::get (LHS0->getType (), ~DiffC));
2330
- Value *NewAdd = Builder.CreateAdd (
2331
- NewAnd, ConstantInt::get (LHS0->getType (), *MaxAddC));
2332
- return Builder.CreateICmp (LHS->getPredicate (), NewAdd,
2333
- ConstantInt::get (LHS0->getType (), *LHSVal));
2329
+ RangeDiff.ugt (LHSC->getValue ())) {
2330
+ Value *MaskC = ConstantInt::get (LAddC->getType (), ~DiffC);
2331
+
2332
+ Value *NewAnd = Builder.CreateAnd (LAddOpnd, MaskC);
2333
+ Value *NewAdd = Builder.CreateAdd (NewAnd, MaxAddC);
2334
+ return Builder.CreateICmp (LHS->getPredicate (), NewAdd, LHSC);
2334
2335
}
2335
2336
}
2336
2337
}
@@ -2416,8 +2417,6 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
2416
2417
}
2417
2418
2418
2419
// This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2).
2419
- auto *LHSC = dyn_cast<ConstantInt>(LHS1);
2420
- auto *RHSC = dyn_cast<ConstantInt>(RHS1);
2421
2420
if (!LHSC || !RHSC)
2422
2421
return nullptr ;
2423
2422
0 commit comments