Skip to content

Commit 6192faf

Browse files
authored
[InstSimplify] Use multi-op replacement when simplify select (#121708)
- **[InstSimplify] Refactor `simplifyWithOpsReplaced` to allow multiple replacements; NFC** - **[InstSimplify] Use multi-op replacement when simplify `select`** In the case of `select X | Y == 0 :...` or `select X & Y == -1 : ...` we can do more simplifications by trying to replace both `X` and `Y` with the respective constant at once. Handles some cases for #121672 more generically.
1 parent 292c135 commit 6192faf

File tree

2 files changed

+76
-64
lines changed

2 files changed

+76
-64
lines changed

llvm/lib/Analysis/InstructionSimplify.cpp

Lines changed: 57 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -4275,25 +4275,27 @@ Value *llvm::simplifyFCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
42754275
return ::simplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
42764276
}
42774277

4278-
static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
4279-
const SimplifyQuery &Q,
4280-
bool AllowRefinement,
4281-
SmallVectorImpl<Instruction *> *DropFlags,
4282-
unsigned MaxRecurse) {
4278+
static Value *simplifyWithOpsReplaced(Value *V,
4279+
ArrayRef<std::pair<Value *, Value *>> Ops,
4280+
const SimplifyQuery &Q,
4281+
bool AllowRefinement,
4282+
SmallVectorImpl<Instruction *> *DropFlags,
4283+
unsigned MaxRecurse) {
42834284
assert((AllowRefinement || !Q.CanUseUndef) &&
42844285
"If AllowRefinement=false then CanUseUndef=false");
4286+
for (const auto &OpAndRepOp : Ops) {
4287+
// We cannot replace a constant, and shouldn't even try.
4288+
if (isa<Constant>(OpAndRepOp.first))
4289+
return nullptr;
42854290

4286-
// Trivial replacement.
4287-
if (V == Op)
4288-
return RepOp;
4291+
// Trivial replacement.
4292+
if (V == OpAndRepOp.first)
4293+
return OpAndRepOp.second;
4294+
}
42894295

42904296
if (!MaxRecurse--)
42914297
return nullptr;
42924298

4293-
// We cannot replace a constant, and shouldn't even try.
4294-
if (isa<Constant>(Op))
4295-
return nullptr;
4296-
42974299
auto *I = dyn_cast<Instruction>(V);
42984300
if (!I)
42994301
return nullptr;
@@ -4303,11 +4305,6 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
43034305
if (isa<PHINode>(I))
43044306
return nullptr;
43054307

4306-
// For vector types, the simplification must hold per-lane, so forbid
4307-
// potentially cross-lane operations like shufflevector.
4308-
if (Op->getType()->isVectorTy() && !isNotCrossLaneOperation(I))
4309-
return nullptr;
4310-
43114308
// Don't fold away llvm.is.constant checks based on assumptions.
43124309
if (match(I, m_Intrinsic<Intrinsic::is_constant>()))
43134310
return nullptr;
@@ -4316,12 +4313,20 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
43164313
if (isa<FreezeInst>(I))
43174314
return nullptr;
43184315

4316+
for (const auto &OpAndRepOp : Ops) {
4317+
// For vector types, the simplification must hold per-lane, so forbid
4318+
// potentially cross-lane operations like shufflevector.
4319+
if (OpAndRepOp.first->getType()->isVectorTy() &&
4320+
!isNotCrossLaneOperation(I))
4321+
return nullptr;
4322+
}
4323+
43194324
// Replace Op with RepOp in instruction operands.
43204325
SmallVector<Value *, 8> NewOps;
43214326
bool AnyReplaced = false;
43224327
for (Value *InstOp : I->operands()) {
4323-
if (Value *NewInstOp = simplifyWithOpReplaced(
4324-
InstOp, Op, RepOp, Q, AllowRefinement, DropFlags, MaxRecurse)) {
4328+
if (Value *NewInstOp = simplifyWithOpsReplaced(
4329+
InstOp, Ops, Q, AllowRefinement, DropFlags, MaxRecurse)) {
43254330
NewOps.push_back(NewInstOp);
43264331
AnyReplaced = InstOp != NewInstOp;
43274332
} else {
@@ -4372,7 +4377,8 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
43724377
// by assumption and this case never wraps, so nowrap flags can be
43734378
// ignored.
43744379
if ((Opcode == Instruction::Sub || Opcode == Instruction::Xor) &&
4375-
NewOps[0] == RepOp && NewOps[1] == RepOp)
4380+
NewOps[0] == NewOps[1] &&
4381+
any_of(Ops, [=](const auto &Rep) { return NewOps[0] == Rep.second; }))
43764382
return Constant::getNullValue(I->getType());
43774383

43784384
// If we are substituting an absorber constant into a binop and extra
@@ -4382,10 +4388,10 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
43824388
// (Op == 0) ? 0 : (Op & -Op) --> Op & -Op
43834389
// (Op == 0) ? 0 : (Op * (binop Op, C)) --> Op * (binop Op, C)
43844390
// (Op == -1) ? -1 : (Op | (binop C, Op) --> Op | (binop C, Op)
4385-
Constant *Absorber =
4386-
ConstantExpr::getBinOpAbsorber(Opcode, I->getType());
4391+
Constant *Absorber = ConstantExpr::getBinOpAbsorber(Opcode, I->getType());
43874392
if ((NewOps[0] == Absorber || NewOps[1] == Absorber) &&
4388-
impliesPoison(BO, Op))
4393+
any_of(Ops,
4394+
[=](const auto &Rep) { return impliesPoison(BO, Rep.first); }))
43894395
return Absorber;
43904396
}
43914397

@@ -4453,6 +4459,15 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
44534459
/*AllowNonDeterministic=*/false);
44544460
}
44554461

4462+
static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
4463+
const SimplifyQuery &Q,
4464+
bool AllowRefinement,
4465+
SmallVectorImpl<Instruction *> *DropFlags,
4466+
unsigned MaxRecurse) {
4467+
return simplifyWithOpsReplaced(V, {{Op, RepOp}}, Q, AllowRefinement,
4468+
DropFlags, MaxRecurse);
4469+
}
4470+
44564471
Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
44574472
const SimplifyQuery &Q,
44584473
bool AllowRefinement,
@@ -4595,21 +4610,20 @@ static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,
45954610

45964611
/// Try to simplify a select instruction when its condition operand is an
45974612
/// integer equality or floating-point equivalence comparison.
4598-
static Value *simplifySelectWithEquivalence(Value *CmpLHS, Value *CmpRHS,
4599-
Value *TrueVal, Value *FalseVal,
4600-
const SimplifyQuery &Q,
4601-
unsigned MaxRecurse) {
4613+
static Value *simplifySelectWithEquivalence(
4614+
ArrayRef<std::pair<Value *, Value *>> Replacements, Value *TrueVal,
4615+
Value *FalseVal, const SimplifyQuery &Q, unsigned MaxRecurse) {
46024616
Value *SimplifiedFalseVal =
4603-
simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q.getWithoutUndef(),
4604-
/* AllowRefinement */ false,
4605-
/* DropFlags */ nullptr, MaxRecurse);
4617+
simplifyWithOpsReplaced(FalseVal, Replacements, Q.getWithoutUndef(),
4618+
/* AllowRefinement */ false,
4619+
/* DropFlags */ nullptr, MaxRecurse);
46064620
if (!SimplifiedFalseVal)
46074621
SimplifiedFalseVal = FalseVal;
46084622

46094623
Value *SimplifiedTrueVal =
4610-
simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
4611-
/* AllowRefinement */ true,
4612-
/* DropFlags */ nullptr, MaxRecurse);
4624+
simplifyWithOpsReplaced(TrueVal, Replacements, Q,
4625+
/* AllowRefinement */ true,
4626+
/* DropFlags */ nullptr, MaxRecurse);
46134627
if (!SimplifiedTrueVal)
46144628
SimplifiedTrueVal = TrueVal;
46154629

@@ -4707,10 +4721,10 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
47074721
// the arms of the select. See if substituting this value into the arm and
47084722
// simplifying the result yields the same value as the other arm.
47094723
if (Pred == ICmpInst::ICMP_EQ) {
4710-
if (Value *V = simplifySelectWithEquivalence(CmpLHS, CmpRHS, TrueVal,
4724+
if (Value *V = simplifySelectWithEquivalence({{CmpLHS, CmpRHS}}, TrueVal,
47114725
FalseVal, Q, MaxRecurse))
47124726
return V;
4713-
if (Value *V = simplifySelectWithEquivalence(CmpRHS, CmpLHS, TrueVal,
4727+
if (Value *V = simplifySelectWithEquivalence({{CmpRHS, CmpLHS}}, TrueVal,
47144728
FalseVal, Q, MaxRecurse))
47154729
return V;
47164730

@@ -4720,23 +4734,17 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
47204734
if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y))) &&
47214735
match(CmpRHS, m_Zero())) {
47224736
// (X | Y) == 0 implies X == 0 and Y == 0.
4723-
if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
4724-
Q, MaxRecurse))
4725-
return V;
4726-
if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
4727-
Q, MaxRecurse))
4737+
if (Value *V = simplifySelectWithEquivalence(
4738+
{{X, CmpRHS}, {Y, CmpRHS}}, TrueVal, FalseVal, Q, MaxRecurse))
47284739
return V;
47294740
}
47304741

47314742
// select((X & Y) == -1 ? X : -1) --> -1 (commuted 2 ways)
47324743
if (match(CmpLHS, m_And(m_Value(X), m_Value(Y))) &&
47334744
match(CmpRHS, m_AllOnes())) {
47344745
// (X & Y) == -1 implies X == -1 and Y == -1.
4735-
if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
4736-
Q, MaxRecurse))
4737-
return V;
4738-
if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
4739-
Q, MaxRecurse))
4746+
if (Value *V = simplifySelectWithEquivalence(
4747+
{{X, CmpRHS}, {Y, CmpRHS}}, TrueVal, FalseVal, Q, MaxRecurse))
47404748
return V;
47414749
}
47424750
}
@@ -4765,11 +4773,11 @@ static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F,
47654773
// This transforms is safe if at least one operand is known to not be zero.
47664774
// Otherwise, the select can change the sign of a zero operand.
47674775
if (IsEquiv) {
4768-
if (Value *V =
4769-
simplifySelectWithEquivalence(CmpLHS, CmpRHS, T, F, Q, MaxRecurse))
4776+
if (Value *V = simplifySelectWithEquivalence({{CmpLHS, CmpRHS}}, T, F, Q,
4777+
MaxRecurse))
47704778
return V;
4771-
if (Value *V =
4772-
simplifySelectWithEquivalence(CmpRHS, CmpLHS, T, F, Q, MaxRecurse))
4779+
if (Value *V = simplifySelectWithEquivalence({{CmpRHS, CmpLHS}}, T, F, Q,
4780+
MaxRecurse))
47734781
return V;
47744782
}
47754783

llvm/test/Transforms/InstCombine/select.ll

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3937,11 +3937,8 @@ entry:
39373937
define i32 @src_or_eq_0_and_xor(i32 %x, i32 %y) {
39383938
; CHECK-LABEL: @src_or_eq_0_and_xor(
39393939
; CHECK-NEXT: entry:
3940-
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
3941-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[OR]], 0
3942-
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y]], [[X]]
3943-
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[XOR]]
3944-
; CHECK-NEXT: ret i32 [[COND]]
3940+
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
3941+
; CHECK-NEXT: ret i32 [[XOR]]
39453942
;
39463943
entry:
39473944
%or = or i32 %y, %x
@@ -3956,11 +3953,8 @@ entry:
39563953
define i32 @src_or_eq_0_xor_and(i32 %x, i32 %y) {
39573954
; CHECK-LABEL: @src_or_eq_0_xor_and(
39583955
; CHECK-NEXT: entry:
3959-
; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
3960-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[OR]], 0
3961-
; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y]], [[X]]
3962-
; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[AND]]
3963-
; CHECK-NEXT: ret i32 [[COND]]
3956+
; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y:%.*]], [[X:%.*]]
3957+
; CHECK-NEXT: ret i32 [[AND]]
39643958
;
39653959
entry:
39663960
%or = or i32 %y, %x
@@ -4438,11 +4432,8 @@ define i32 @src_no_trans_select_and_eq0_xor_and(i32 %x, i32 %y) {
44384432

44394433
define i32 @src_no_trans_select_or_eq0_or_and(i32 %x, i32 %y) {
44404434
; CHECK-LABEL: @src_no_trans_select_or_eq0_or_and(
4441-
; CHECK-NEXT: [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
4442-
; CHECK-NEXT: [[OR0:%.*]] = icmp eq i32 [[OR]], 0
4443-
; CHECK-NEXT: [[AND:%.*]] = and i32 [[X]], [[Y]]
4444-
; CHECK-NEXT: [[COND:%.*]] = select i1 [[OR0]], i32 0, i32 [[AND]]
4445-
; CHECK-NEXT: ret i32 [[COND]]
4435+
; CHECK-NEXT: [[AND:%.*]] = and i32 [[X:%.*]], [[Y:%.*]]
4436+
; CHECK-NEXT: ret i32 [[AND]]
44464437
;
44474438
%or = or i32 %x, %y
44484439
%or0 = icmp eq i32 %or, 0
@@ -4837,3 +4828,16 @@ define i32 @replace_and_cond_multiuse2(i1 %cond1, i1 %cond2) {
48374828
%mux = select i1 %cond1, i32 %sel, i32 1
48384829
ret i32 %mux
48394830
}
4831+
4832+
define i32 @src_simplify_2x_at_once_and(i32 %x, i32 %y) {
4833+
; CHECK-LABEL: @src_simplify_2x_at_once_and(
4834+
; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]]
4835+
; CHECK-NEXT: ret i32 [[XOR]]
4836+
;
4837+
%and = and i32 %x, %y
4838+
%and0 = icmp eq i32 %and, -1
4839+
%sub = sub i32 %x, %y
4840+
%xor = xor i32 %x, %y
4841+
%cond = select i1 %and0, i32 %sub, i32 %xor
4842+
ret i32 %cond
4843+
}

0 commit comments

Comments
 (0)