Skip to content

Commit 90ebc20

Browse files
mskampsvkeerthy
authored andcommitted
[SDAG] Make Select-with-Identity-Fold More Flexible; NFC (#136554)
This change adds new parameters to the method `shouldFoldSelectWithIdentityConstant()`. The method now takes the opcode of the select node and the non-identity operand of the select node. To gain access to the appropriate arguments, the call of `shouldFoldSelectWithIdentityConstant()` is moved after all other checks have been performed. Moreover, this change adjusts the precondition of the fold so that it would work for `SELECT` nodes in addition to `VSELECT` nodes. No functional change is intended because all implementations of `shouldFoldSelectWithIdentityConstant()` are adjusted such that they restrict the fold to a `VSELECT` node; the same restriction as before. The rationale of this change is to make more fine grained decisions possible when to revert the InstCombine canonicalization of `(select c (binop x y) y)` to `(binop (select c x idc) y)` in the backends.
1 parent 035d639 commit 90ebc20

File tree

10 files changed

+54
-33
lines changed

10 files changed

+54
-33
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3390,8 +3390,10 @@ class LLVM_ABI TargetLoweringBase {
33903390
/// Return true if pulling a binary operation into a select with an identity
33913391
/// constant is profitable. This is the inverse of an IR transform.
33923392
/// Example: X + (Cond ? Y : 0) --> Cond ? (X + Y) : X
3393-
virtual bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
3394-
EVT VT) const {
3393+
virtual bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
3394+
unsigned SelectOpcode,
3395+
SDValue X,
3396+
SDValue Y) const {
33953397
return false;
33963398
}
33973399

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2433,8 +2433,9 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
24332433
if (ShouldCommuteOperands)
24342434
std::swap(N0, N1);
24352435

2436-
// TODO: Should this apply to scalar select too?
2437-
if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse())
2436+
unsigned SelOpcode = N1.getOpcode();
2437+
if ((SelOpcode != ISD::VSELECT && SelOpcode != ISD::SELECT) ||
2438+
!N1.hasOneUse())
24382439
return SDValue();
24392440

24402441
// We can't hoist all instructions because of immediate UB (not speculatable).
@@ -2447,17 +2448,22 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
24472448
SDValue Cond = N1.getOperand(0);
24482449
SDValue TVal = N1.getOperand(1);
24492450
SDValue FVal = N1.getOperand(2);
2451+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24502452

24512453
// This transform increases uses of N0, so freeze it to be safe.
24522454
// binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
24532455
unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
2454-
if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo)) {
2456+
if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo) &&
2457+
TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
2458+
FVal)) {
24552459
SDValue F0 = DAG.getFreeze(N0);
24562460
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
24572461
return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
24582462
}
24592463
// binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
2460-
if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo)) {
2464+
if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo) &&
2465+
TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
2466+
TVal)) {
24612467
SDValue F0 = DAG.getFreeze(N0);
24622468
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
24632469
return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
@@ -2467,26 +2473,23 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
24672473
}
24682474

24692475
SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
2476+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
24702477
assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
24712478
"Unexpected binary operator");
24722479

2473-
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2474-
auto BinOpcode = BO->getOpcode();
2475-
EVT VT = BO->getValueType(0);
2476-
if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
2477-
if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
2478-
return Sel;
2480+
if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
2481+
return Sel;
24792482

2480-
if (TLI.isCommutativeBinOp(BO->getOpcode()))
2481-
if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
2482-
return Sel;
2483-
}
2483+
if (TLI.isCommutativeBinOp(BO->getOpcode()))
2484+
if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
2485+
return Sel;
24842486

24852487
// Don't do this unless the old select is going away. We want to eliminate the
24862488
// binary operator, not replace a binop with a select.
24872489
// TODO: Handle ISD::SELECT_CC.
24882490
unsigned SelOpNo = 0;
24892491
SDValue Sel = BO->getOperand(0);
2492+
auto BinOpcode = BO->getOpcode();
24902493
if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
24912494
SelOpNo = 1;
24922495
Sel = BO->getOperand(1);
@@ -2534,6 +2537,7 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
25342537

25352538
SDLoc DL(Sel);
25362539
SDValue NewCT, NewCF;
2540+
EVT VT = BO->getValueType(0);
25372541

25382542
if (CanFoldNonConst) {
25392543
// If CBO is an opaque constant, we can't rely on getNode to constant fold.

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17685,8 +17685,10 @@ bool AArch64TargetLowering::shouldFoldConstantShiftPairToMask(
1768517685
}
1768617686

1768717687
bool AArch64TargetLowering::shouldFoldSelectWithIdentityConstant(
17688-
unsigned BinOpcode, EVT VT) const {
17689-
return VT.isScalableVector() && isTypeLegal(VT);
17688+
unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
17689+
SDValue Y) const {
17690+
return VT.isScalableVector() && isTypeLegal(VT) &&
17691+
SelectOpcode == ISD::VSELECT;
1769017692
}
1769117693

1769217694
bool AArch64TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,9 @@ class AArch64TargetLowering : public TargetLowering {
281281
bool shouldFoldConstantShiftPairToMask(const SDNode *N,
282282
CombineLevel Level) const override;
283283

284-
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
285-
EVT VT) const override;
284+
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
285+
unsigned SelectOpcode, SDValue X,
286+
SDValue Y) const override;
286287

287288
/// Returns true if it is beneficial to convert a load of a constant
288289
/// to just the constant itself.

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13957,9 +13957,11 @@ bool ARMTargetLowering::shouldFoldConstantShiftPairToMask(
1395713957
return false;
1395813958
}
1395913959

13960-
bool ARMTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
13961-
EVT VT) const {
13962-
return Subtarget->hasMVEIntegerOps() && isTypeLegal(VT);
13960+
bool ARMTargetLowering::shouldFoldSelectWithIdentityConstant(
13961+
unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
13962+
SDValue Y) const {
13963+
return Subtarget->hasMVEIntegerOps() && isTypeLegal(VT) &&
13964+
SelectOpcode == ISD::VSELECT;
1396313965
}
1396413966

1396513967
bool ARMTargetLowering::preferIncOfAddToSubOfNot(EVT VT) const {

llvm/lib/Target/ARM/ARMISelLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -758,8 +758,9 @@ class VectorType;
758758
bool shouldFoldConstantShiftPairToMask(const SDNode *N,
759759
CombineLevel Level) const override;
760760

761-
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
762-
EVT VT) const override;
761+
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
762+
unsigned SelectOpcode, SDValue X,
763+
SDValue Y) const override;
763764

764765
bool preferIncOfAddToSubOfNot(EVT VT) const override;
765766

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2145,8 +2145,12 @@ bool RISCVTargetLowering::hasBitTest(SDValue X, SDValue Y) const {
21452145
return C && C->getAPIntValue().ule(10);
21462146
}
21472147

2148-
bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
2149-
EVT VT) const {
2148+
bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(
2149+
unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
2150+
SDValue Y) const {
2151+
if (SelectOpcode != ISD::VSELECT)
2152+
return false;
2153+
21502154
// Only enable for rvv.
21512155
if (!VT.isVector() || !Subtarget.hasVInstructions())
21522156
return false;

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ class RISCVTargetLowering : public TargetLowering {
9595
unsigned &NumIntermediates,
9696
MVT &RegisterVT) const override;
9797

98-
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
99-
EVT VT) const override;
98+
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
99+
unsigned SelectOpcode, SDValue X,
100+
SDValue Y) const override;
100101

101102
/// Return true if the given shuffle mask can be codegen'd directly, or if it
102103
/// should be stack expanded.

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35549,8 +35549,11 @@ bool X86TargetLowering::isNarrowingProfitable(SDNode *N, EVT SrcVT,
3554935549
return !(SrcVT == MVT::i32 && DestVT == MVT::i16);
3555035550
}
3555135551

35552-
bool X86TargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
35553-
EVT VT) const {
35552+
bool X86TargetLowering::shouldFoldSelectWithIdentityConstant(
35553+
unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
35554+
SDValue Y) const {
35555+
if (SelectOpcode != ISD::VSELECT)
35556+
return false;
3555435557
// TODO: This is too general. There are cases where pre-AVX512 codegen would
3555535558
// benefit. The transform may also be profitable for scalar code.
3555635559
if (!Subtarget.hasAVX512())

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,8 +1467,9 @@ namespace llvm {
14671467
/// from i32 to i16.
14681468
bool isNarrowingProfitable(SDNode *N, EVT SrcVT, EVT DestVT) const override;
14691469

1470-
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
1471-
EVT VT) const override;
1470+
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
1471+
unsigned SelectOpcode, SDValue X,
1472+
SDValue Y) const override;
14721473

14731474
/// Given an intrinsic, checks if on the target the intrinsic will need to map
14741475
/// to a MemIntrinsicNode (touches memory). If this is the case, it returns

0 commit comments

Comments
 (0)