Skip to content

Commit 1834ec7

Browse files
committed
[InstCombine] Improve eq/ne by parts to handle ult/ugt equality pattern.
`(icmp eq/ne (lshr x, C), (lshr y, C))` gets optimized to `(icmp ult/uge (xor x, y), (1 << C))`, where the `uge` form is canonicalized to `(icmp ugt (xor x, y), (1 << C) - 1)`. This can cause the current eq-of-parts detection to miss the high bits, because their comparison may already have been optimized to the new pattern. This commit adds support for detecting and combining the ult/ugt pattern.
1 parent 8208f91 commit 1834ec7

File tree

2 files changed

+35
-18
lines changed

2 files changed

+35
-18
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,13 +1146,40 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1,
11461146
return nullptr;
11471147

11481148
CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
1149-
if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred)
1150-
return nullptr;
1149+
auto GetMatchPart = [&](ICmpInst *Cmp,
1150+
unsigned OpNo) -> std::optional<IntPart> {
1151+
if (Pred == Cmp->getPredicate())
1152+
return matchIntPart(Cmp->getOperand(OpNo));
1153+
1154+
const APInt *C;
1155+
// (icmp eq (lshr x, C), (lshr y, C)) gets optimized to:
1156+
// (icmp ult (xor x, y), 1 << C) so also look for that.
1157+
if (Pred == CmpInst::ICMP_EQ && Cmp->getPredicate() == CmpInst::ICMP_ULT) {
1158+
if (!match(Cmp->getOperand(1), m_Power2(C)) ||
1159+
!match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())))
1160+
return std::nullopt;
1161+
}
1162+
1163+
// (icmp ne (lshr x, C), (lshr y, C)) gets optimized to:
1164+
// (icmp ugt (xor x, y), (1 << C) - 1) so also look for that.
1165+
else if (Pred == CmpInst::ICMP_NE &&
1166+
Cmp->getPredicate() == CmpInst::ICMP_UGT) {
1167+
if (!match(Cmp->getOperand(1), m_LowBitMask(C)) ||
1168+
!match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())))
1169+
return std::nullopt;
1170+
} else {
1171+
return std::nullopt;
1172+
}
1173+
1174+
unsigned From = Pred == CmpInst::ICMP_NE ? C->popcount() : C->countr_zero();
1175+
Instruction *I = cast<Instruction>(Cmp->getOperand(0));
1176+
return {{I->getOperand(OpNo), From, C->getBitWidth() - From}};
1177+
};
11511178

1152-
std::optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
1153-
std::optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
1154-
std::optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
1155-
std::optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
1179+
std::optional<IntPart> L0 = GetMatchPart(Cmp0, 0);
1180+
std::optional<IntPart> R0 = GetMatchPart(Cmp0, 1);
1181+
std::optional<IntPart> L1 = GetMatchPart(Cmp1, 0);
1182+
std::optional<IntPart> R1 = GetMatchPart(Cmp1, 1);
11561183
if (!L0 || !R0 || !L1 || !R1)
11571184
return nullptr;
11581185

llvm/test/Transforms/InstCombine/eq-of-parts.ll

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,12 +1336,7 @@ define i1 @ne_21_wrong_pred2(i32 %x, i32 %y) {
13361336

13371337
define i1 @eq_optimized_highbits_cmp(i32 %x, i32 %y) {
13381338
; CHECK-LABEL: @eq_optimized_highbits_cmp(
1339-
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
1340-
; CHECK-NEXT: [[CMP_HI:%.*]] = icmp ult i32 [[XOR]], 33554432
1341-
; CHECK-NEXT: [[TX:%.*]] = trunc i32 [[X]] to i25
1342-
; CHECK-NEXT: [[TY:%.*]] = trunc i32 [[Y]] to i25
1343-
; CHECK-NEXT: [[CMP_LO:%.*]] = icmp eq i25 [[TX]], [[TY]]
1344-
; CHECK-NEXT: [[R:%.*]] = and i1 [[CMP_HI]], [[CMP_LO]]
1339+
; CHECK-NEXT: [[R:%.*]] = icmp eq i32 [[Y:%.*]], [[X:%.*]]
13451340
; CHECK-NEXT: ret i1 [[R]]
13461341
;
13471342
%xor = xor i32 %y, %x
@@ -1393,12 +1388,7 @@ define i1 @eq_optimized_highbits_cmp_fail_not_pow2(i32 %x, i32 %y) {
13931388

13941389
define i1 @ne_optimized_highbits_cmp(i32 %x, i32 %y) {
13951390
; CHECK-LABEL: @ne_optimized_highbits_cmp(
1396-
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
1397-
; CHECK-NEXT: [[CMP_HI:%.*]] = icmp ugt i32 [[XOR]], 16777215
1398-
; CHECK-NEXT: [[TX:%.*]] = trunc i32 [[X]] to i24
1399-
; CHECK-NEXT: [[TY:%.*]] = trunc i32 [[Y]] to i24
1400-
; CHECK-NEXT: [[CMP_LO:%.*]] = icmp ne i24 [[TX]], [[TY]]
1401-
; CHECK-NEXT: [[R:%.*]] = or i1 [[CMP_HI]], [[CMP_LO]]
1391+
; CHECK-NEXT: [[R:%.*]] = icmp ne i32 [[Y:%.*]], [[X:%.*]]
14021392
; CHECK-NEXT: ret i1 [[R]]
14031393
;
14041394
%xor = xor i32 %y, %x

0 commit comments

Comments
 (0)