Skip to content

Commit 976b1ae

Browse files
committed
[SDAG] Make Select-with-Identity-Fold More Flexible; NFC
This change adds new parameters to the method `shouldFoldSelectWithIdentityConstant()`. The method now takes the opcode of the select node and the non-identity operand of the select node. To gain access to the appropriate arguments, the call of `shouldFoldSelectWithIdentityConstant()` is moved after all other checks have been performed. Moreover, this change adjusts the precondition of the fold so that it would work for `SELECT` nodes in addition to `VSELECT` nodes. No functional change is intended because all implementations of `shouldFoldSelectWithIdentityConstant()` are adjusted such that they restrict the fold to a `VSELECT` node; the same restriction as before. The rationale of this change is to make more fine grained decisions possible when to revert the InstCombine canonicalization of `(select c (binop x y) y)` to `(binop (select c x idc) y)` in the backends.
1 parent b6820c3 commit 976b1ae

File tree

10 files changed

+54
-33
lines changed

10 files changed

+54
-33
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3353,8 +3353,10 @@ class TargetLoweringBase {
33533353
/// Return true if pulling a binary operation into a select with an identity
33543354
/// constant is profitable. This is the inverse of an IR transform.
33553355
/// Example: X + (Cond ? Y : 0) --> Cond ? (X + Y) : X
3356-
virtual bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
3357-
EVT VT) const {
3356+
virtual bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
3357+
unsigned SelectOpcode,
3358+
SDValue X,
3359+
SDValue Y) const {
33583360
return false;
33593361
}
33603362

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,8 +2425,9 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
24252425
if (ShouldCommuteOperands)
24262426
std::swap(N0, N1);
24272427

2428-
// TODO: Should this apply to scalar select too?
2429-
if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse())
2428+
unsigned SelOpcode = N1.getOpcode();
2429+
if ((SelOpcode != ISD::VSELECT && SelOpcode != ISD::SELECT) ||
2430+
!N1.hasOneUse())
24302431
return SDValue();
24312432

24322433
// We can't hoist all instructions because of immediate UB (not speculatable).
@@ -2439,17 +2440,22 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
24392440
SDValue Cond = N1.getOperand(0);
24402441
SDValue TVal = N1.getOperand(1);
24412442
SDValue FVal = N1.getOperand(2);
2443+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24422444

24432445
// This transform increases uses of N0, so freeze it to be safe.
24442446
// binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
24452447
unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2446-
if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo)) {
2448+
if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo) &&
2449+
TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
2450+
FVal)) {
24472451
SDValue F0 = DAG.getFreeze(N0);
24482452
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
24492453
return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
24502454
}
24512455
// binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2452-
if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo)) {
2456+
if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo) &&
2457+
TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
2458+
TVal)) {
24532459
SDValue F0 = DAG.getFreeze(N0);
24542460
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
24552461
return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
@@ -2459,26 +2465,23 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
24592465
}
24602466

24612467
SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2468+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24622469
assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
24632470
"Unexpected binary operator");
24642471

2465-
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2466-
auto BinOpcode = BO->getOpcode();
2467-
EVT VT = BO->getValueType(0);
2468-
if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
2469-
if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
2470-
return Sel;
2472+
if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
2473+
return Sel;
24712474

2472-
if (TLI.isCommutativeBinOp(BO->getOpcode()))
2473-
if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
2474-
return Sel;
2475-
}
2475+
if (TLI.isCommutativeBinOp(BO->getOpcode()))
2476+
if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
2477+
return Sel;
24762478

24772479
// Don't do this unless the old select is going away. We want to eliminate the
24782480
// binary operator, not replace a binop with a select.
24792481
// TODO: Handle ISD::SELECT_CC.
24802482
unsigned SelOpNo = 0;
24812483
SDValue Sel = BO->getOperand(0);
2484+
auto BinOpcode = BO->getOpcode();
24822485
if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
24832486
SelOpNo = 1;
24842487
Sel = BO->getOperand(1);
@@ -2526,6 +2529,7 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
25262529

25272530
SDLoc DL(Sel);
25282531
SDValue NewCT, NewCF;
2532+
EVT VT = BO->getValueType(0);
25292533

25302534
if (CanFoldNonConst) {
25312535
// If CBO is an opaque constant, we can't rely on getNode to constant fold.

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18040,8 +18040,10 @@ bool AArch64TargetLowering::shouldFoldConstantShiftPairToMask(
1804018040
}
1804118041

1804218042
bool AArch64TargetLowering::shouldFoldSelectWithIdentityConstant(
18043-
unsigned BinOpcode, EVT VT) const {
18044-
return VT.isScalableVector() && isTypeLegal(VT);
18043+
unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
18044+
SDValue Y) const {
18045+
return VT.isScalableVector() && isTypeLegal(VT) &&
18046+
SelectOpcode == ISD::VSELECT;
1804518047
}
1804618048

1804718049
bool AArch64TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -786,8 +786,9 @@ class AArch64TargetLowering : public TargetLowering {
786786
bool shouldFoldConstantShiftPairToMask(const SDNode *N,
787787
CombineLevel Level) const override;
788788

789-
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
790-
EVT VT) const override;
789+
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
790+
unsigned SelectOpcode, SDValue X,
791+
SDValue Y) const override;
791792

792793
/// Returns true if it is beneficial to convert a load of a constant
793794
/// to just the constant itself.

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13960,9 +13960,11 @@ bool ARMTargetLowering::shouldFoldConstantShiftPairToMask(
1396013960
return false;
1396113961
}
1396213962

13963-
bool ARMTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
13964-
EVT VT) const {
13965-
return Subtarget->hasMVEIntegerOps() && isTypeLegal(VT);
13963+
bool ARMTargetLowering::shouldFoldSelectWithIdentityConstant(
13964+
unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
13965+
SDValue Y) const {
13966+
return Subtarget->hasMVEIntegerOps() && isTypeLegal(VT) &&
13967+
SelectOpcode == ISD::VSELECT;
1396613968
}
1396713969

1396813970
bool ARMTargetLowering::preferIncOfAddToSubOfNot(EVT VT) const {

llvm/lib/Target/ARM/ARMISelLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -758,8 +758,9 @@ class VectorType;
758758
bool shouldFoldConstantShiftPairToMask(const SDNode *N,
759759
CombineLevel Level) const override;
760760

761-
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
762-
EVT VT) const override;
761+
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
762+
unsigned SelectOpcode, SDValue X,
763+
SDValue Y) const override;
763764

764765
bool preferIncOfAddToSubOfNot(EVT VT) const override;
765766

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2090,8 +2090,12 @@ bool RISCVTargetLowering::hasBitTest(SDValue X, SDValue Y) const {
20902090
return C && C->getAPIntValue().ule(10);
20912091
}
20922092

2093-
bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
2094-
EVT VT) const {
2093+
bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(
2094+
unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
2095+
SDValue Y) const {
2096+
if (SelectOpcode != ISD::VSELECT)
2097+
return false;
2098+
20952099
// Only enable for rvv.
20962100
if (!VT.isVector() || !Subtarget.hasVInstructions())
20972101
return false;

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -585,8 +585,9 @@ class RISCVTargetLowering : public TargetLowering {
585585
unsigned &NumIntermediates,
586586
MVT &RegisterVT) const override;
587587

588-
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
589-
EVT VT) const override;
588+
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
589+
unsigned SelectOpcode, SDValue X,
590+
SDValue Y) const override;
590591

591592
/// Return true if the given shuffle mask can be codegen'd directly, or if it
592593
/// should be stack expanded.

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35383,8 +35383,11 @@ bool X86TargetLowering::isNarrowingProfitable(SDNode *N, EVT SrcVT,
3538335383
return !(SrcVT == MVT::i32 && DestVT == MVT::i16);
3538435384
}
3538535385

35386-
bool X86TargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
35387-
EVT VT) const {
35386+
bool X86TargetLowering::shouldFoldSelectWithIdentityConstant(
35387+
unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
35388+
SDValue Y) const {
35389+
if (SelectOpcode != ISD::VSELECT)
35390+
return false;
3538835391
// TODO: This is too general. There are cases where pre-AVX512 codegen would
3538935392
// benefit. The transform may also be profitable for scalar code.
3539035393
if (!Subtarget.hasAVX512())

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,8 +1460,9 @@ namespace llvm {
14601460
/// from i32 to i16.
14611461
bool isNarrowingProfitable(SDNode *N, EVT SrcVT, EVT DestVT) const override;
14621462

1463-
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
1464-
EVT VT) const override;
1463+
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
1464+
unsigned SelectOpcode, SDValue X,
1465+
SDValue Y) const override;
14651466

14661467
/// Given an intrinsic, checks if on the target the intrinsic will need to map
14671468
/// to a MemIntrinsicNode (touches memory). If this is the case, it returns

0 commit comments

Comments
 (0)