Skip to content

[InstCombine] Extend bitmask mul combine to handle independent operands #142503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 111 additions & 21 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2367,6 +2367,7 @@ Value *InstCombinerImpl::reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y,
Instruction &I, bool IsAnd,
bool RHSIsLogical) {
Instruction::BinaryOps Opcode = IsAnd ? Instruction::And : Instruction::Or;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unrelated changes.

// LHS bop (X lop Y) --> (LHS bop X) lop Y
// LHS bop (X bop Y) --> (LHS bop X) bop Y
if (Value *Res = foldBooleanAndOr(LHS, X, I, IsAnd, /*IsLogical=*/false))
Expand All @@ -2377,6 +2378,7 @@ Value *InstCombinerImpl::reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y,
if (Value *Res = foldBooleanAndOr(LHS, Y, I, IsAnd, /*IsLogical=*/false))
return RHSIsLogical ? Builder.CreateLogicalOp(Opcode, X, Res)
: Builder.CreateBinOp(Opcode, X, Res);

return nullptr;
}

Expand Down Expand Up @@ -3602,6 +3604,11 @@ struct DecomposedBitMaskMul {
APInt Mask;
bool NUW;
bool NSW;

bool isCombineableWith(DecomposedBitMaskMul Other) {
return X == Other.X && (Mask & Other.Mask).isZero() &&
Factor == Other.Factor;
}
};

static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
Expand Down Expand Up @@ -3659,6 +3666,106 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
return std::nullopt;
}

// (A & N) * C + (A & M) * C -> (A & (N + M)) & C
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use ///

// This also accepts the equivalent select form of (A & N) * C
// expressions i.e. !(A & N) ? 0 : N * C)
static Value *foldBitmaskMul(Value *Op0, Value *Op1,
InstCombiner::BuilderTy &Builder) {
auto Decomp1 = matchBitmaskMul(Op1);

if (Decomp1) {
auto Decomp0 = matchBitmaskMul(Op0);

if (Decomp0) {
// If we have independent operands in the BitmaskMul chain, then just
// reassociate to encourage combining in future iterations.

if (Decomp0->isCombineableWith(*Decomp1)) {
auto NewAnd = Builder.CreateAnd(
Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
(Decomp0->Mask + Decomp1->Mask)));

auto Res = Builder.CreateMul(
NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor), "",
Decomp0->NUW && Decomp1->NUW, Decomp0->NSW && Decomp1->NSW);
return Res;
}
}
}

return nullptr;
}

Value *InstCombinerImpl::foldDisjointOr(Value *LHS, Value *RHS,
Instruction &I) {
if (Value *Res = foldBitmaskMul(LHS, RHS, Builder)) {
return Res;
}

return nullptr;
}

Value *InstCombinerImpl::reassociateDisjointOr(Value *LHS, Value *RHS,
Instruction &I) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instruction &I is unused.


Value *X, *Y;
if (match(RHS, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y))))) {
if (Value *Res = foldDisjointOr(LHS, X, I)) {
auto Disjoint = cast<PossiblyDisjointInst>(Builder.CreateOr(Res, Y));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a parameter IsDisjoint = false to IRBuilder::CreateOr?

Disjoint->setIsDisjoint(true);
return cast<Value>(Disjoint);
}
if (Value *Res = foldDisjointOr(LHS, Y, I)) {
auto Disjoint = cast<PossiblyDisjointInst>(Builder.CreateOr(Res, X));
Disjoint->setIsDisjoint(true);
return cast<Value>(Disjoint);
}
}

if (match(LHS, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y))))) {
if (Value *Res = foldDisjointOr(X, RHS, I)) {
auto Disjoint = cast<PossiblyDisjointInst>(Builder.CreateOr(Res, Y));
Disjoint->setIsDisjoint(true);
return cast<Value>(Disjoint);
}
if (Value *Res = foldDisjointOr(Y, RHS, I)) {
auto Disjoint = cast<PossiblyDisjointInst>(Builder.CreateOr(Res, X));
Disjoint->setIsDisjoint(true);
return cast<Value>(Disjoint);
}
}

Value *X1, *Y1;
if (match(LHS, m_OneUse(m_DisjointOr(m_Value(X), m_Value(Y)))) &&
(match(RHS, m_OneUse(m_DisjointOr(m_Value(X1), m_Value(Y1)))))) {
auto TryFold = [this, &I](Value *Op0, Value *Op1, Value *Rem0,
Value *Rem1) -> Value * {
if (Value *Res = foldDisjointOr(Op0, Op1, I)) {
auto Disjoint =
cast<PossiblyDisjointInst>(Builder.CreateOr(Rem0, Rem1));
Disjoint->setIsDisjoint(true);
auto Disjoint2 =
cast<PossiblyDisjointInst>(Builder.CreateOr(Disjoint, Res));
return cast<Value>(Disjoint2);
}
return nullptr;
};

if (Value *Res = TryFold(X, X1, Y, Y1))
return Res;

if (Value *Res = TryFold(X, Y1, Y, X1))
return Res;

if (Value *Res = TryFold(Y, X1, X, Y1))
return Res;

if (Value *Res = TryFold(Y, Y1, X, X1))
return Res;
}
return nullptr;
}

// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
Expand Down Expand Up @@ -3741,28 +3848,11 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
/*NSW=*/true, /*NUW=*/true))
return R;

// (A & N) * C + (A & M) * C -> (A & (N + M)) & C
// This also accepts the equivalent select form of (A & N) * C
// expressions i.e. !(A & N) ? 0 : N * C)
auto Decomp1 = matchBitmaskMul(I.getOperand(1));
if (Decomp1) {
auto Decomp0 = matchBitmaskMul(I.getOperand(0));
if (Decomp0 && Decomp0->X == Decomp1->X &&
(Decomp0->Mask & Decomp1->Mask).isZero() &&
Decomp0->Factor == Decomp1->Factor) {

Value *NewAnd = Builder.CreateAnd(
Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
(Decomp0->Mask + Decomp1->Mask)));

auto *Combined = BinaryOperator::CreateMul(
NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
if (Value *Res = foldBitmaskMul(I.getOperand(0), I.getOperand(1), Builder))
return replaceInstUsesWith(I, Res);

Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW);
Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW);
return Combined;
}
}
if (Value *Res = reassociateDisjointOr(I.getOperand(0), I.getOperand(1), I))
return replaceInstUsesWith(I, Res);
}

Value *X, *Y;
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Value *reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y, Instruction &I,
bool IsAnd, bool RHSIsLogical);

Value *foldDisjointOr(Value *LHS, Value *RHS, Instruction &I);

Value *reassociateDisjointOr(Value *LHS, Value *RHS, Instruction &I);

Instruction *
canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i);

Expand Down
Loading