Skip to content

Commit f89cd57

Browse files
committed
[InstCombine] Fold Xor with or disjoint
Implement a missing optimization to fold (A | B) ^ C to (A ^ C) ^ B
1 parent 77b75e6 commit f89cd57

File tree

2 files changed

+70
-60
lines changed

2 files changed

+70
-60
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 66 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,16 @@ Value *InstCombinerImpl::insertRangeTest(Value *V, const APInt &Lo,
101101
/// (icmp eq (A & B), A) equals (icmp ne (A & B), 0)
102102
/// (icmp ne (A & B), A) equals (icmp eq (A & B), 0)
103103
enum MaskedICmpType {
104-
AMask_AllOnes = 1,
105-
AMask_NotAllOnes = 2,
106-
BMask_AllOnes = 4,
107-
BMask_NotAllOnes = 8,
108-
Mask_AllZeros = 16,
109-
Mask_NotAllZeros = 32,
110-
AMask_Mixed = 64,
111-
AMask_NotMixed = 128,
112-
BMask_Mixed = 256,
113-
BMask_NotMixed = 512
104+
AMask_AllOnes = 1,
105+
AMask_NotAllOnes = 2,
106+
BMask_AllOnes = 4,
107+
BMask_NotAllOnes = 8,
108+
Mask_AllZeros = 16,
109+
Mask_NotAllZeros = 32,
110+
AMask_Mixed = 64,
111+
AMask_NotMixed = 128,
112+
BMask_Mixed = 256,
113+
BMask_NotMixed = 512
114114
};
115115

116116
/// Return the set of patterns (from MaskedICmpType) that (icmp SCC (A & B), C)
@@ -172,15 +172,16 @@ static unsigned conjugateICmpMask(unsigned Mask) {
172172
<< 1;
173173

174174
NewMask |= (Mask & (AMask_NotAllOnes | BMask_NotAllOnes | Mask_NotAllZeros |
175-
AMask_NotMixed | BMask_NotMixed))
176-
>> 1;
175+
AMask_NotMixed | BMask_NotMixed)) >>
176+
1;
177177

178178
return NewMask;
179179
}
180180

181181
// Adapts the external decomposeBitTestICmp for local use.
182-
static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
183-
Value *&X, Value *&Y, Value *&Z) {
182+
static bool decomposeBitTestICmp(Value *LHS, Value *RHS,
183+
CmpInst::Predicate &Pred, Value *&X, Value *&Y,
184+
Value *&Z) {
184185
APInt Mask;
185186
if (!llvm::decomposeBitTestICmp(LHS, RHS, Pred, X, Mask))
186187
return false;
@@ -519,9 +520,9 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
519520
if (Mask == 0) {
520521
// Even if the two sides don't share a common pattern, check if folding can
521522
// still happen.
522-
if (Value *V = foldLogOpOfMaskedICmpsAsymmetric(
523-
LHS, RHS, IsAnd, A, B, C, D, E, PredL, PredR, LHSMask, RHSMask,
524-
Builder))
523+
if (Value *V = foldLogOpOfMaskedICmpsAsymmetric(LHS, RHS, IsAnd, A, B, C, D,
524+
E, PredL, PredR, LHSMask,
525+
RHSMask, Builder))
525526
return V;
526527
return nullptr;
527528
}
@@ -680,16 +681,16 @@ Value *InstCombinerImpl::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1,
680681
if (!RangeStart)
681682
return nullptr;
682683

683-
ICmpInst::Predicate Pred0 = (Inverted ? Cmp0->getInversePredicate() :
684-
Cmp0->getPredicate());
684+
ICmpInst::Predicate Pred0 =
685+
(Inverted ? Cmp0->getInversePredicate() : Cmp0->getPredicate());
685686

686687
// Accept x > -1 or x >= 0 (after potentially inverting the predicate).
687688
if (!((Pred0 == ICmpInst::ICMP_SGT && RangeStart->isMinusOne()) ||
688689
(Pred0 == ICmpInst::ICMP_SGE && RangeStart->isZero())))
689690
return nullptr;
690691

691-
ICmpInst::Predicate Pred1 = (Inverted ? Cmp1->getInversePredicate() :
692-
Cmp1->getPredicate());
692+
ICmpInst::Predicate Pred1 =
693+
(Inverted ? Cmp1->getInversePredicate() : Cmp1->getPredicate());
693694

694695
Value *Input = Cmp0->getOperand(0);
695696
Value *RangeEnd;
@@ -707,9 +708,14 @@ Value *InstCombinerImpl::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1,
707708
// Check the upper range comparison, e.g. x < n
708709
ICmpInst::Predicate NewPred;
709710
switch (Pred1) {
710-
case ICmpInst::ICMP_SLT: NewPred = ICmpInst::ICMP_ULT; break;
711-
case ICmpInst::ICMP_SLE: NewPred = ICmpInst::ICMP_ULE; break;
712-
default: return nullptr;
711+
case ICmpInst::ICMP_SLT:
712+
NewPred = ICmpInst::ICMP_ULT;
713+
break;
714+
case ICmpInst::ICMP_SLE:
715+
NewPred = ICmpInst::ICMP_ULE;
716+
break;
717+
default:
718+
return nullptr;
713719
}
714720

715721
// This simplification is only valid if the upper range is not negative.
@@ -785,8 +791,7 @@ Value *InstCombinerImpl::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS,
785791
if (L2 == R1)
786792
std::swap(L1, L2);
787793

788-
if (L1 == R1 &&
789-
isKnownToBeAPowerOfTwo(L2, false, 0, CxtI) &&
794+
if (L1 == R1 && isKnownToBeAPowerOfTwo(L2, false, 0, CxtI) &&
790795
isKnownToBeAPowerOfTwo(R2, false, 0, CxtI)) {
791796
// If this is a logical and/or, then we must prevent propagation of a
792797
// poison value from the RHS by inserting freeze.
@@ -1636,8 +1641,8 @@ static Instruction *reassociateFCmps(BinaryOperator &BO,
16361641

16371642
// Match inner binop and the predicate for combining 2 NAN checks into 1.
16381643
Value *BO10, *BO11;
1639-
FCmpInst::Predicate NanPred = Opcode == Instruction::And ? FCmpInst::FCMP_ORD
1640-
: FCmpInst::FCMP_UNO;
1644+
FCmpInst::Predicate NanPred =
1645+
Opcode == Instruction::And ? FCmpInst::FCMP_ORD : FCmpInst::FCMP_UNO;
16411646
if (!match(Op0, m_SpecificFCmp(NanPred, m_Value(X), m_AnyZeroFP())) ||
16421647
!match(Op1, m_BinOp(Opcode, m_Value(BO10), m_Value(BO11))))
16431648
return nullptr;
@@ -1666,8 +1671,7 @@ static Instruction *reassociateFCmps(BinaryOperator &BO,
16661671
/// Match variations of De Morgan's Laws:
16671672
/// (~A & ~B) == (~(A | B))
16681673
/// (~A | ~B) == (~(A & B))
1669-
static Instruction *matchDeMorgansLaws(BinaryOperator &I,
1670-
InstCombiner &IC) {
1674+
static Instruction *matchDeMorgansLaws(BinaryOperator &I, InstCombiner &IC) {
16711675
const Instruction::BinaryOps Opcode = I.getOpcode();
16721676
assert((Opcode == Instruction::And || Opcode == Instruction::Or) &&
16731677
"Trying to match De Morgan's Laws with something other than and/or");
@@ -1841,10 +1845,10 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) {
18411845
Value *Cast1Src = Cast1->getOperand(0);
18421846

18431847
// fold logic(cast(A), cast(B)) -> cast(logic(A, B))
1844-
if ((Cast0->hasOneUse() || Cast1->hasOneUse()) &&
1845-
shouldOptimizeCast(Cast0) && shouldOptimizeCast(Cast1)) {
1846-
Value *NewOp = Builder.CreateBinOp(LogicOpc, Cast0Src, Cast1Src,
1847-
I.getName());
1848+
if ((Cast0->hasOneUse() || Cast1->hasOneUse()) && shouldOptimizeCast(Cast0) &&
1849+
shouldOptimizeCast(Cast1)) {
1850+
Value *NewOp =
1851+
Builder.CreateBinOp(LogicOpc, Cast0Src, Cast1Src, I.getName());
18481852
return CastInst::Create(CastOpcode, NewOp, DestTy);
18491853
}
18501854

@@ -2530,7 +2534,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
25302534
int Log2ShiftC = ShiftC->exactLogBase2();
25312535
int Log2C = C->exactLogBase2();
25322536
bool IsShiftLeft =
2533-
cast<BinaryOperator>(Op0)->getOpcode() == Instruction::Shl;
2537+
cast<BinaryOperator>(Op0)->getOpcode() == Instruction::Shl;
25342538
int BitNum = IsShiftLeft ? Log2C - Log2ShiftC : Log2ShiftC - Log2C;
25352539
assert(BitNum >= 0 && "Expected demanded bits to handle impossible mask");
25362540
Value *Cmp = Builder.CreateICmpEQ(X, ConstantInt::get(Ty, BitNum));
@@ -3475,8 +3479,8 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
34753479
}
34763480
} else {
34773481
if ((TrueIfSignedL && !TrueIfSignedR &&
3478-
match(LHS0, m_And(m_Value(X), m_Value(Y))) &&
3479-
match(RHS0, m_c_Or(m_Specific(X), m_Specific(Y)))) ||
3482+
match(LHS0, m_And(m_Value(X), m_Value(Y))) &&
3483+
match(RHS0, m_c_Or(m_Specific(X), m_Specific(Y)))) ||
34803484
(!TrueIfSignedL && TrueIfSignedR &&
34813485
match(LHS0, m_Or(m_Value(X), m_Value(Y))) &&
34823486
match(RHS0, m_c_And(m_Specific(X), m_Specific(Y))))) {
@@ -4163,8 +4167,8 @@ Value *InstCombinerImpl::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS,
41634167
isSignBitCheck(PredL, *LC, TrueIfSignedL) &&
41644168
isSignBitCheck(PredR, *RC, TrueIfSignedR)) {
41654169
Value *XorLR = Builder.CreateXor(LHS0, RHS0);
4166-
return TrueIfSignedL == TrueIfSignedR ? Builder.CreateIsNeg(XorLR) :
4167-
Builder.CreateIsNotNeg(XorLR);
4170+
return TrueIfSignedL == TrueIfSignedR ? Builder.CreateIsNeg(XorLR)
4171+
: Builder.CreateIsNotNeg(XorLR);
41684172
}
41694173

41704174
// Fold (icmp pred1 X, C1) ^ (icmp pred2 X, C2)
@@ -4343,8 +4347,8 @@ static Instruction *canonicalizeAbs(BinaryOperator &Xor,
43434347
Type *Ty = Xor.getType();
43444348
Value *A;
43454349
const APInt *ShAmt;
4346-
if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) &&
4347-
Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 &&
4350+
if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && Op1->hasNUses(2) &&
4351+
*ShAmt == Ty->getScalarSizeInBits() - 1 &&
43484352
match(Op0, m_OneUse(m_c_Add(m_Specific(A), m_Specific(Op1))))) {
43494353
// Op1 = ashr i32 A, 31 ; smear the sign bit
43504354
// xor (add A, Op1), Op1 ; add -1 and flip bits if negative
@@ -4580,7 +4584,8 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
45804584

45814585
// Move a 'not' ahead of casts of a bool to enable logic reduction:
45824586
// not (bitcast (sext i1 X)) --> bitcast (sext (not i1 X))
4583-
if (match(NotOp, m_OneUse(m_BitCast(m_OneUse(m_SExt(m_Value(X)))))) && X->getType()->isIntOrIntVectorTy(1)) {
4587+
if (match(NotOp, m_OneUse(m_BitCast(m_OneUse(m_SExt(m_Value(X)))))) &&
4588+
X->getType()->isIntOrIntVectorTy(1)) {
45844589
Type *SextTy = cast<BitCastOperator>(NotOp)->getSrcTy();
45854590
Value *NotX = Builder.CreateNot(X);
45864591
Value *Sext = Builder.CreateSExt(NotX, SextTy);
@@ -4693,7 +4698,21 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
46934698
// calls in there are unnecessary as SimplifyDemandedInstructionBits should
46944699
// have already taken care of those cases.
46954700
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
4696-
Value *M;
4701+
Value *X, *Y, *M;
4702+
4703+
// (A | B) ^ C -> (A ^ C) ^ B
4704+
// C ^ (A | B) -> B ^ (A ^ C)
4705+
if (match(&I, m_c_Xor(m_OneUse(m_c_DisjointOr(m_Value(X), m_Value(Y))),
4706+
m_Value(M)))) {
4707+
if (Value *XorAC = simplifyBinOp(Instruction::Xor, X, M, SQ)) {
4708+
return BinaryOperator::CreateXor(XorAC, Y);
4709+
}
4710+
4711+
if (Value *XorBC = simplifyBinOp(Instruction::Xor, Y, M, SQ)) {
4712+
return BinaryOperator::CreateXor(XorBC, X);
4713+
}
4714+
}
4715+
46974716
if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()),
46984717
m_c_And(m_Deferred(M), m_Value())))) {
46994718
if (isGuaranteedNotToBeUndef(M))
@@ -4705,7 +4724,6 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
47054724
if (Instruction *Xor = visitMaskedMerge(I, Builder))
47064725
return Xor;
47074726

4708-
Value *X, *Y;
47094727
Constant *C1;
47104728
if (match(Op1, m_Constant(C1))) {
47114729
Constant *C2;
@@ -4870,14 +4888,14 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
48704888
// (A ^ B) ^ (A | C) --> (~A & C) ^ B -- There are 4 commuted variants.
48714889
if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_Value(A), m_Value(B))),
48724890
m_OneUse(m_c_Or(m_Deferred(A), m_Value(C))))))
4873-
return BinaryOperator::CreateXor(
4874-
Builder.CreateAnd(Builder.CreateNot(A), C), B);
4891+
return BinaryOperator::CreateXor(Builder.CreateAnd(Builder.CreateNot(A), C),
4892+
B);
48754893

48764894
// (A ^ B) ^ (B | C) --> (~B & C) ^ A -- There are 4 commuted variants.
48774895
if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_Value(A), m_Value(B))),
48784896
m_OneUse(m_c_Or(m_Deferred(B), m_Value(C))))))
4879-
return BinaryOperator::CreateXor(
4880-
Builder.CreateAnd(Builder.CreateNot(B), C), A);
4897+
return BinaryOperator::CreateXor(Builder.CreateAnd(Builder.CreateNot(B), C),
4898+
A);
48814899

48824900
// (A & B) ^ (A ^ B) -> (A | B)
48834901
if (match(Op0, m_And(m_Value(A), m_Value(B))) &&

llvm/test/Transforms/InstCombine/xor.ll

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,9 +1489,7 @@ define i4 @PR96857_xor_without_noundef(i4 %val0, i4 %val1, i4 %val2) {
14891489
define i32 @or_disjoint_with_xor(i32 %a, i32 %b) {
14901490
; CHECK-LABEL: @or_disjoint_with_xor(
14911491
; CHECK-NEXT: entry:
1492-
; CHECK-NEXT: [[TMP0:%.*]] = xor i32 [[A:%.*]], -1
1493-
; CHECK-NEXT: [[XOR:%.*]] = and i32 [[B:%.*]], [[TMP0]]
1494-
; CHECK-NEXT: ret i32 [[XOR]]
1492+
; CHECK-NEXT: ret i32 [[B:%.*]]
14951493
;
14961494
entry:
14971495
%or = or disjoint i32 %a, %b
@@ -1502,9 +1500,7 @@ entry:
15021500
define i32 @xor_with_or_disjoint(i32 %a, i32 %b, i32 %c) {
15031501
; CHECK-LABEL: @xor_with_or_disjoint(
15041502
; CHECK-NEXT: entry:
1505-
; CHECK-NEXT: [[TMP0:%.*]] = xor i32 [[A:%.*]], -1
1506-
; CHECK-NEXT: [[XOR:%.*]] = and i32 [[B:%.*]], [[TMP0]]
1507-
; CHECK-NEXT: ret i32 [[XOR]]
1503+
; CHECK-NEXT: ret i32 [[B:%.*]]
15081504
;
15091505
entry:
15101506
%or = or disjoint i32 %a, %b
@@ -1515,9 +1511,7 @@ entry:
15151511
define <2 x i32> @or_disjoint_with_xor_vec(<2 x i32> %a, < 2 x i32> %b, <2 x i32> %c) {
15161512
; CHECK-LABEL: @or_disjoint_with_xor_vec(
15171513
; CHECK-NEXT: entry:
1518-
; CHECK-NEXT: [[TMP0:%.*]] = xor <2 x i32> [[A:%.*]], <i32 -1, i32 -1>
1519-
; CHECK-NEXT: [[XOR:%.*]] = and <2 x i32> [[B:%.*]], [[TMP0]]
1520-
; CHECK-NEXT: ret <2 x i32> [[XOR]]
1514+
; CHECK-NEXT: ret <2 x i32> [[B:%.*]]
15211515
;
15221516
entry:
15231517
%or = or disjoint <2 x i32> %a, %b
@@ -1528,9 +1522,7 @@ entry:
15281522
define <2 x i32> @xor_with_or_disjoint_vec(<2 x i32> %a, < 2 x i32> %b, <2 x i32> %c) {
15291523
; CHECK-LABEL: @xor_with_or_disjoint_vec(
15301524
; CHECK-NEXT: entry:
1531-
; CHECK-NEXT: [[TMP0:%.*]] = xor <2 x i32> [[A:%.*]], <i32 -1, i32 -1>
1532-
; CHECK-NEXT: [[XOR:%.*]] = and <2 x i32> [[B:%.*]], [[TMP0]]
1533-
; CHECK-NEXT: ret <2 x i32> [[XOR]]
1525+
; CHECK-NEXT: ret <2 x i32> [[B:%.*]]
15341526
;
15351527
entry:
15361528
%or = or disjoint <2 x i32> %a, %b

0 commit comments

Comments
 (0)