-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[InstCombine] Extend bitmask mul combine to handle independent operands #142503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e9996d1
171ca6b
f46aaa7
eda45da
15fa10a
132a12b
d32b91d
1909cd9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2367,6 +2367,7 @@ Value *InstCombinerImpl::reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y, | |
Instruction &I, bool IsAnd, | ||
bool RHSIsLogical) { | ||
Instruction::BinaryOps Opcode = IsAnd ? Instruction::And : Instruction::Or; | ||
|
||
// LHS bop (X lop Y) --> (LHS bop X) lop Y | ||
// LHS bop (X bop Y) --> (LHS bop X) bop Y | ||
if (Value *Res = foldBooleanAndOr(LHS, X, I, IsAnd, /*IsLogical=*/false)) | ||
|
@@ -2377,6 +2378,7 @@ Value *InstCombinerImpl::reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y, | |
if (Value *Res = foldBooleanAndOr(LHS, Y, I, IsAnd, /*IsLogical=*/false)) | ||
return RHSIsLogical ? Builder.CreateLogicalOp(Opcode, X, Res) | ||
: Builder.CreateBinOp(Opcode, X, Res); | ||
|
||
return nullptr; | ||
} | ||
|
||
|
@@ -3602,6 +3604,11 @@ struct DecomposedBitMaskMul { | |
APInt Mask; | ||
bool NUW; | ||
bool NSW; | ||
|
||
bool isCombineableWith(DecomposedBitMaskMul Other) { | ||
return X == Other.X && (Mask & Other.Mask).isZero() && | ||
Factor == Other.Factor; | ||
} | ||
}; | ||
|
||
static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) { | ||
|
@@ -3659,6 +3666,106 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) { | |
return std::nullopt; | ||
} | ||
|
||
// (A & N) * C + (A & M) * C -> (A & (N + M)) * C | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use |
||
// This also accepts the equivalent select form of (A & N) * C | ||
// expressions, i.e. (!(A & N) ? 0 : N * C). | ||
/// Tries to merge two bitmask-multiply operands into one masked multiply.
///
/// Both \p Op0 and \p Op1 are decomposed with matchBitmaskMul. If both
/// decompose and the results are combineable (see
/// DecomposedBitMaskMul::isCombineableWith), this emits
///   mul (and X, Mask0 + Mask1), Factor
/// through \p Builder and returns it. The nuw/nsw flags are set only when
/// both decompositions carry them. Returns nullptr when no fold applies.
static Value *foldBitmaskMul(Value *Op0, Value *Op1,
                             InstCombiner::BuilderTy &Builder) {
  // Op1 is matched first; Op0 is only inspected once Op1 qualifies.
  std::optional<DecomposedBitMaskMul> RHSParts = matchBitmaskMul(Op1);
  if (!RHSParts)
    return nullptr;

  std::optional<DecomposedBitMaskMul> LHSParts = matchBitmaskMul(Op0);
  if (!LHSParts)
    return nullptr;

  if (!LHSParts->isCombineableWith(*RHSParts))
    return nullptr;

  // Combine the two masks into one constant and apply it to X.
  auto *MergedMask = ConstantInt::get(LHSParts->X->getType(),
                                      (LHSParts->Mask + RHSParts->Mask));
  auto *Masked = Builder.CreateAnd(LHSParts->X, MergedMask);

  // Scale by the shared factor, keeping only the wrap flags common to both.
  const bool HasNUW = LHSParts->NUW && RHSParts->NUW;
  const bool HasNSW = LHSParts->NSW && RHSParts->NSW;
  auto *Scale = ConstantInt::get(Masked->getType(), RHSParts->Factor);
  return Builder.CreateMul(Masked, Scale, "", HasNUW, HasNSW);
}
|
||
/// Attempts the folds available for the operand pair (\p LHS, \p RHS).
///
/// The only fold tried is foldBitmaskMul; its result is returned directly,
/// so the return value is nullptr when that fold does not apply. \p I is
/// accepted but not consulted here.
Value *InstCombinerImpl::foldDisjointOr(Value *LHS, Value *RHS,
                                        Instruction &I) {
  return foldBitmaskMul(LHS, RHS, Builder);
}
|
||
/// Reassociates a chain of disjoint ors so that two operands accepted by
/// foldDisjointOr end up adjacent and can be folded together.
///
/// Three shapes are tried, in order:
///   LHS | (X | Y)          -- RHS is a one-use disjoint or
///   (X | Y) | RHS          -- LHS is a one-use disjoint or
///   (X | Y) | (X1 | Y1)    -- both sides are one-use disjoint ors
/// For each shape every pairing of one operand from each side is offered to
/// foldDisjointOr; on the first success the folded value is or'ed back with
/// the leftover operand(s) and returned. Returns nullptr if nothing folds.
///
/// NOTE(review): marking the rebuilt ors as disjoint presumes the or being
/// visited is itself disjoint; the visible call site appears to sit inside
/// such a guard -- confirm against the caller.
Value *InstCombinerImpl::reassociateDisjointOr(Value *LHS, Value *RHS,
                                               Instruction &I) {
  // Builds `A | B` and flags it disjoint. The builder may hand back a value
  // that is not an instruction (e.g. a constant-folded result), so the flag
  // is only applied when the result really is a PossiblyDisjointInst rather
  // than asserting through an unchecked cast.
  auto CreateDisjointOr = [this](Value *A, Value *B) -> Value * {
    Value *Or = Builder.CreateOr(A, B);
    if (auto *PDI = dyn_cast<PossiblyDisjointInst>(Or))
      PDI->setIsDisjoint(true);
    return Or;
  };

  // Folds (Op0, Op1) and, on success, ors the single leftover operand back.
  auto FoldWithRest = [&](Value *Op0, Value *Op1, Value *Rest) -> Value * {
    if (Value *Res = foldDisjointOr(Op0, Op1, I))
      return CreateDisjointOr(Res, Rest);
    return nullptr;
  };

  Value *X, *Y;

  // LHS | (X | Y)
  if (match(RHS, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y))))) {
    if (Value *Res = FoldWithRest(LHS, X, Y))
      return Res;
    if (Value *Res = FoldWithRest(LHS, Y, X))
      return Res;
  }

  // (X | Y) | RHS
  if (match(LHS, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y))))) {
    if (Value *Res = FoldWithRest(X, RHS, Y))
      return Res;
    if (Value *Res = FoldWithRest(Y, RHS, X))
      return Res;
  }

  // (X | Y) | (X1 | Y1)
  Value *X1, *Y1;
  if (match(LHS, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y)))) &&
      match(RHS, m_OneUse(m_DisjointOr(m_Value(X1), m_Value(Y1))))) {
    // Folds (Op0, Op1) and recombines with both leftovers. The leftovers are
    // or'ed first, then joined with the folded value; both ors are flagged
    // disjoint, consistent with the single-leftover paths above.
    auto TryFold = [&](Value *Op0, Value *Op1, Value *Rem0,
                       Value *Rem1) -> Value * {
      if (Value *Res = foldDisjointOr(Op0, Op1, I)) {
        Value *Rest = CreateDisjointOr(Rem0, Rem1);
        return CreateDisjointOr(Rest, Res);
      }
      return nullptr;
    };

    if (Value *Res = TryFold(X, X1, Y, Y1))
      return Res;

    if (Value *Res = TryFold(X, Y1, Y, X1))
      return Res;

    if (Value *Res = TryFold(Y, X1, X, Y1))
      return Res;

    if (Value *Res = TryFold(Y, Y1, X, X1))
      return Res;
  }

  return nullptr;
}
|
||
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches | ||
// here. We should standardize that construct where it is needed or choose some | ||
// other way to ensure that commutated variants of patterns are not missed. | ||
|
@@ -3741,28 +3848,11 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { | |
/*NSW=*/true, /*NUW=*/true)) | ||
return R; | ||
|
||
// (A & N) * C + (A & M) * C -> (A & (N + M)) & C | ||
// This also accepts the equivalent select form of (A & N) * C | ||
// expressions i.e. !(A & N) ? 0 : N * C) | ||
auto Decomp1 = matchBitmaskMul(I.getOperand(1)); | ||
if (Decomp1) { | ||
auto Decomp0 = matchBitmaskMul(I.getOperand(0)); | ||
if (Decomp0 && Decomp0->X == Decomp1->X && | ||
(Decomp0->Mask & Decomp1->Mask).isZero() && | ||
Decomp0->Factor == Decomp1->Factor) { | ||
|
||
Value *NewAnd = Builder.CreateAnd( | ||
Decomp0->X, ConstantInt::get(Decomp0->X->getType(), | ||
(Decomp0->Mask + Decomp1->Mask))); | ||
|
||
auto *Combined = BinaryOperator::CreateMul( | ||
NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor)); | ||
if (Value *Res = foldBitmaskMul(I.getOperand(0), I.getOperand(1), Builder)) | ||
return replaceInstUsesWith(I, Res); | ||
|
||
Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW); | ||
Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW); | ||
return Combined; | ||
} | ||
} | ||
if (Value *Res = reassociateDisjointOr(I.getOperand(0), I.getOperand(1), I)) | ||
return replaceInstUsesWith(I, Res); | ||
} | ||
|
||
Value *X, *Y; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove unrelated changes.