Skip to content

[InstSimplify] Use multi-op replacement when simplify select #121708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 57 additions & 49 deletions llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4275,25 +4275,27 @@ Value *llvm::simplifyFCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
return ::simplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
}

static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
const SimplifyQuery &Q,
bool AllowRefinement,
SmallVectorImpl<Instruction *> *DropFlags,
unsigned MaxRecurse) {
static Value *simplifyWithOpsReplaced(Value *V,
ArrayRef<std::pair<Value *, Value *>> Ops,
const SimplifyQuery &Q,
bool AllowRefinement,
SmallVectorImpl<Instruction *> *DropFlags,
unsigned MaxRecurse) {
assert((AllowRefinement || !Q.CanUseUndef) &&
"If AllowRefinement=false then CanUseUndef=false");
for (const auto &OpAndRepOp : Ops) {
// We cannot replace a constant, and shouldn't even try.
if (isa<Constant>(OpAndRepOp.first))
return nullptr;

// Trivial replacement.
if (V == Op)
return RepOp;
// Trivial replacement.
if (V == OpAndRepOp.first)
return OpAndRepOp.second;
}

if (!MaxRecurse--)
return nullptr;

// We cannot replace a constant, and shouldn't even try.
if (isa<Constant>(Op))
return nullptr;

auto *I = dyn_cast<Instruction>(V);
if (!I)
return nullptr;
Expand All @@ -4303,11 +4305,6 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
if (isa<PHINode>(I))
return nullptr;

// For vector types, the simplification must hold per-lane, so forbid
// potentially cross-lane operations like shufflevector.
if (Op->getType()->isVectorTy() && !isNotCrossLaneOperation(I))
return nullptr;

// Don't fold away llvm.is.constant checks based on assumptions.
if (match(I, m_Intrinsic<Intrinsic::is_constant>()))
return nullptr;
Expand All @@ -4316,12 +4313,20 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
if (isa<FreezeInst>(I))
return nullptr;

for (const auto &OpAndRepOp : Ops) {
// For vector types, the simplification must hold per-lane, so forbid
// potentially cross-lane operations like shufflevector.
if (OpAndRepOp.first->getType()->isVectorTy() &&
!isNotCrossLaneOperation(I))
return nullptr;
}

// Replace Op with RepOp in instruction operands.
SmallVector<Value *, 8> NewOps;
bool AnyReplaced = false;
for (Value *InstOp : I->operands()) {
if (Value *NewInstOp = simplifyWithOpReplaced(
InstOp, Op, RepOp, Q, AllowRefinement, DropFlags, MaxRecurse)) {
if (Value *NewInstOp = simplifyWithOpsReplaced(
InstOp, Ops, Q, AllowRefinement, DropFlags, MaxRecurse)) {
NewOps.push_back(NewInstOp);
AnyReplaced = InstOp != NewInstOp;
} else {
Expand Down Expand Up @@ -4372,7 +4377,8 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
// by assumption and this case never wraps, so nowrap flags can be
// ignored.
if ((Opcode == Instruction::Sub || Opcode == Instruction::Xor) &&
NewOps[0] == RepOp && NewOps[1] == RepOp)
NewOps[0] == NewOps[1] &&
any_of(Ops, [=](const auto &Rep) { return NewOps[0] == Rep.second; }))
return Constant::getNullValue(I->getType());

// If we are substituting an absorber constant into a binop and extra
Expand All @@ -4382,10 +4388,10 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
// (Op == 0) ? 0 : (Op & -Op) --> Op & -Op
// (Op == 0) ? 0 : (Op * (binop Op, C)) --> Op * (binop Op, C)
// (Op == -1) ? -1 : (Op | (binop C, Op) --> Op | (binop C, Op)
Constant *Absorber =
ConstantExpr::getBinOpAbsorber(Opcode, I->getType());
Constant *Absorber = ConstantExpr::getBinOpAbsorber(Opcode, I->getType());
if ((NewOps[0] == Absorber || NewOps[1] == Absorber) &&
impliesPoison(BO, Op))
any_of(Ops,
[=](const auto &Rep) { return impliesPoison(BO, Rep.first); }))
return Absorber;
}

Expand Down Expand Up @@ -4453,6 +4459,15 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
/*AllowNonDeterministic=*/false);
}

static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
const SimplifyQuery &Q,
bool AllowRefinement,
SmallVectorImpl<Instruction *> *DropFlags,
unsigned MaxRecurse) {
return simplifyWithOpsReplaced(V, {{Op, RepOp}}, Q, AllowRefinement,
DropFlags, MaxRecurse);
}

Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
const SimplifyQuery &Q,
bool AllowRefinement,
Expand Down Expand Up @@ -4595,21 +4610,20 @@ static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,

/// Try to simplify a select instruction when its condition operand is an
/// integer equality or floating-point equivalence comparison.
static Value *simplifySelectWithEquivalence(Value *CmpLHS, Value *CmpRHS,
Value *TrueVal, Value *FalseVal,
const SimplifyQuery &Q,
unsigned MaxRecurse) {
static Value *simplifySelectWithEquivalence(
ArrayRef<std::pair<Value *, Value *>> Replacements, Value *TrueVal,
Value *FalseVal, const SimplifyQuery &Q, unsigned MaxRecurse) {
Value *SimplifiedFalseVal =
simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q.getWithoutUndef(),
/* AllowRefinement */ false,
/* DropFlags */ nullptr, MaxRecurse);
simplifyWithOpsReplaced(FalseVal, Replacements, Q.getWithoutUndef(),
/* AllowRefinement */ false,
/* DropFlags */ nullptr, MaxRecurse);
if (!SimplifiedFalseVal)
SimplifiedFalseVal = FalseVal;

Value *SimplifiedTrueVal =
simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
/* AllowRefinement */ true,
/* DropFlags */ nullptr, MaxRecurse);
simplifyWithOpsReplaced(TrueVal, Replacements, Q,
/* AllowRefinement */ true,
/* DropFlags */ nullptr, MaxRecurse);
if (!SimplifiedTrueVal)
SimplifiedTrueVal = TrueVal;

Expand Down Expand Up @@ -4707,10 +4721,10 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
// the arms of the select. See if substituting this value into the arm and
// simplifying the result yields the same value as the other arm.
if (Pred == ICmpInst::ICMP_EQ) {
if (Value *V = simplifySelectWithEquivalence(CmpLHS, CmpRHS, TrueVal,
if (Value *V = simplifySelectWithEquivalence({{CmpLHS, CmpRHS}}, TrueVal,
FalseVal, Q, MaxRecurse))
return V;
if (Value *V = simplifySelectWithEquivalence(CmpRHS, CmpLHS, TrueVal,
if (Value *V = simplifySelectWithEquivalence({{CmpRHS, CmpLHS}}, TrueVal,
FalseVal, Q, MaxRecurse))
return V;

Expand All @@ -4720,23 +4734,17 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y))) &&
match(CmpRHS, m_Zero())) {
// (X | Y) == 0 implies X == 0 and Y == 0.
if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
Q, MaxRecurse))
return V;
if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
Q, MaxRecurse))
if (Value *V = simplifySelectWithEquivalence(
{{X, CmpRHS}, {Y, CmpRHS}}, TrueVal, FalseVal, Q, MaxRecurse))
return V;
}

// select((X & Y) == -1 ? X : -1) --> -1 (commuted 2 ways)
if (match(CmpLHS, m_And(m_Value(X), m_Value(Y))) &&
match(CmpRHS, m_AllOnes())) {
// (X & Y) == -1 implies X == -1 and Y == -1.
if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
Q, MaxRecurse))
return V;
if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
Q, MaxRecurse))
if (Value *V = simplifySelectWithEquivalence(
{{X, CmpRHS}, {Y, CmpRHS}}, TrueVal, FalseVal, Q, MaxRecurse))
return V;
}
}
Expand Down Expand Up @@ -4765,11 +4773,11 @@ static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F,
// This transforms is safe if at least one operand is known to not be zero.
// Otherwise, the select can change the sign of a zero operand.
if (IsEquiv) {
if (Value *V =
simplifySelectWithEquivalence(CmpLHS, CmpRHS, T, F, Q, MaxRecurse))
if (Value *V = simplifySelectWithEquivalence({{CmpLHS, CmpRHS}}, T, F, Q,
MaxRecurse))
return V;
if (Value *V =
simplifySelectWithEquivalence(CmpRHS, CmpLHS, T, F, Q, MaxRecurse))
if (Value *V = simplifySelectWithEquivalence({{CmpRHS, CmpLHS}}, T, F, Q,
MaxRecurse))
return V;
}

Expand Down
34 changes: 19 additions & 15 deletions llvm/test/Transforms/InstCombine/select.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3937,11 +3937,8 @@ entry:
define i32 @src_or_eq_0_and_xor(i32 %x, i32 %y) {
; CHECK-LABEL: @src_or_eq_0_and_xor(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[OR]], 0
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y]], [[X]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[XOR]]
; CHECK-NEXT: ret i32 [[COND]]
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: ret i32 [[XOR]]
;
entry:
%or = or i32 %y, %x
Expand All @@ -3956,11 +3953,8 @@ entry:
define i32 @src_or_eq_0_xor_and(i32 %x, i32 %y) {
; CHECK-LABEL: @src_or_eq_0_xor_and(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[OR]], 0
; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y]], [[X]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[AND]]
; CHECK-NEXT: ret i32 [[COND]]
; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: ret i32 [[AND]]
;
entry:
%or = or i32 %y, %x
Expand Down Expand Up @@ -4438,11 +4432,8 @@ define i32 @src_no_trans_select_and_eq0_xor_and(i32 %x, i32 %y) {

define i32 @src_no_trans_select_or_eq0_or_and(i32 %x, i32 %y) {
; CHECK-LABEL: @src_no_trans_select_or_eq0_or_and(
; CHECK-NEXT: [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: [[OR0:%.*]] = icmp eq i32 [[OR]], 0
; CHECK-NEXT: [[AND:%.*]] = and i32 [[X]], [[Y]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[OR0]], i32 0, i32 [[AND]]
; CHECK-NEXT: ret i32 [[COND]]
; CHECK-NEXT: [[AND:%.*]] = and i32 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i32 [[AND]]
;
%or = or i32 %x, %y
%or0 = icmp eq i32 %or, 0
Expand Down Expand Up @@ -4837,3 +4828,16 @@ define i32 @replace_and_cond_multiuse2(i1 %cond1, i1 %cond2) {
%mux = select i1 %cond1, i32 %sel, i32 1
ret i32 %mux
}

define i32 @src_simplify_2x_at_once_and(i32 %x, i32 %y) {
; CHECK-LABEL: @src_simplify_2x_at_once_and(
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i32 [[XOR]]
;
%and = and i32 %x, %y
%and0 = icmp eq i32 %and, -1
%sub = sub i32 %x, %y
%xor = xor i32 %x, %y
%cond = select i1 %and0, i32 %sub, i32 %xor
ret i32 %cond
}
Loading