Skip to content

[SDAG] Make Select-with-Identity-Fold More Flexible; NFC #136554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -3353,8 +3353,10 @@ class TargetLoweringBase {
/// Return true if pulling a binary operation into a select with an identity
/// constant is profitable. This is the inverse of an IR transform.
/// Example: X + (Cond ? Y : 0) --> Cond ? (X + Y) : X
virtual bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
EVT VT) const {
virtual bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
unsigned SelectOpcode,
SDValue X,
SDValue Y) const {
return false;
}

Expand Down
32 changes: 18 additions & 14 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2425,8 +2425,9 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
if (ShouldCommuteOperands)
std::swap(N0, N1);

// TODO: Should this apply to scalar select too?
if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse())
unsigned SelOpcode = N1.getOpcode();
if ((SelOpcode != ISD::VSELECT && SelOpcode != ISD::SELECT) ||
!N1.hasOneUse())
return SDValue();

// We can't hoist all instructions because of immediate UB (not speculatable).
Expand All @@ -2439,17 +2440,22 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
SDValue Cond = N1.getOperand(0);
SDValue TVal = N1.getOperand(1);
SDValue FVal = N1.getOperand(2);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();

// This transform increases uses of N0, so freeze it to be safe.
// binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
unsigned OpNo = ShouldCommuteOperands ? 0 : 1;
if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo)) {
if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo) &&
TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
FVal)) {
SDValue F0 = DAG.getFreeze(N0);
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
}
// binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo)) {
if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo) &&
TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0,
TVal)) {
SDValue F0 = DAG.getFreeze(N0);
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
Expand All @@ -2459,26 +2465,23 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
}

SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 &&
"Unexpected binary operator");

const TargetLowering &TLI = DAG.getTargetLoweringInfo();
auto BinOpcode = BO->getOpcode();
EVT VT = BO->getValueType(0);
if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) {
if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
return Sel;
if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false))
return Sel;

if (TLI.isCommutativeBinOp(BO->getOpcode()))
if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
return Sel;
}
if (TLI.isCommutativeBinOp(BO->getOpcode()))
if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true))
return Sel;

// Don't do this unless the old select is going away. We want to eliminate the
// binary operator, not replace a binop with a select.
// TODO: Handle ISD::SELECT_CC.
unsigned SelOpNo = 0;
SDValue Sel = BO->getOperand(0);
auto BinOpcode = BO->getOpcode();
if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) {
SelOpNo = 1;
Sel = BO->getOperand(1);
Expand Down Expand Up @@ -2526,6 +2529,7 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) {

SDLoc DL(Sel);
SDValue NewCT, NewCF;
EVT VT = BO->getValueType(0);

if (CanFoldNonConst) {
// If CBO is an opaque constant, we can't rely on getNode to constant fold.
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18040,8 +18040,10 @@ bool AArch64TargetLowering::shouldFoldConstantShiftPairToMask(
}

bool AArch64TargetLowering::shouldFoldSelectWithIdentityConstant(
unsigned BinOpcode, EVT VT) const {
return VT.isScalableVector() && isTypeLegal(VT);
unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
SDValue Y) const {
return VT.isScalableVector() && isTypeLegal(VT) &&
SelectOpcode == ISD::VSELECT;
}

bool AArch64TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -786,8 +786,9 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldFoldConstantShiftPairToMask(const SDNode *N,
CombineLevel Level) const override;

bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
EVT VT) const override;
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
unsigned SelectOpcode, SDValue X,
SDValue Y) const override;

/// Returns true if it is beneficial to convert a load of a constant
/// to just the constant itself.
Expand Down
8 changes: 5 additions & 3 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13960,9 +13960,11 @@ bool ARMTargetLowering::shouldFoldConstantShiftPairToMask(
return false;
}

bool ARMTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
EVT VT) const {
return Subtarget->hasMVEIntegerOps() && isTypeLegal(VT);
bool ARMTargetLowering::shouldFoldSelectWithIdentityConstant(
unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
SDValue Y) const {
return Subtarget->hasMVEIntegerOps() && isTypeLegal(VT) &&
SelectOpcode == ISD::VSELECT;
}

bool ARMTargetLowering::preferIncOfAddToSubOfNot(EVT VT) const {
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/ARM/ARMISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -758,8 +758,9 @@ class VectorType;
bool shouldFoldConstantShiftPairToMask(const SDNode *N,
CombineLevel Level) const override;

bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
EVT VT) const override;
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
unsigned SelectOpcode, SDValue X,
SDValue Y) const override;

bool preferIncOfAddToSubOfNot(EVT VT) const override;

Expand Down
8 changes: 6 additions & 2 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2090,8 +2090,12 @@ bool RISCVTargetLowering::hasBitTest(SDValue X, SDValue Y) const {
return C && C->getAPIntValue().ule(10);
}

bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
EVT VT) const {
bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(
unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
SDValue Y) const {
if (SelectOpcode != ISD::VSELECT)
return false;

// Only enable for rvv.
if (!VT.isVector() || !Subtarget.hasVInstructions())
return false;
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,8 +585,9 @@ class RISCVTargetLowering : public TargetLowering {
unsigned &NumIntermediates,
MVT &RegisterVT) const override;

bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
EVT VT) const override;
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
unsigned SelectOpcode, SDValue X,
SDValue Y) const override;

/// Return true if the given shuffle mask can be codegen'd directly, or if it
/// should be stack expanded.
Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35383,8 +35383,11 @@ bool X86TargetLowering::isNarrowingProfitable(SDNode *N, EVT SrcVT,
return !(SrcVT == MVT::i32 && DestVT == MVT::i16);
}

bool X86TargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode,
EVT VT) const {
bool X86TargetLowering::shouldFoldSelectWithIdentityConstant(
unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X,
SDValue Y) const {
if (SelectOpcode != ISD::VSELECT)
return false;
// TODO: This is too general. There are cases where pre-AVX512 codegen would
// benefit. The transform may also be profitable for scalar code.
if (!Subtarget.hasAVX512())
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Target/X86/X86ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1460,8 +1460,9 @@ namespace llvm {
/// from i32 to i16.
bool isNarrowingProfitable(SDNode *N, EVT SrcVT, EVT DestVT) const override;

bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
EVT VT) const override;
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT,
unsigned SelectOpcode, SDValue X,
SDValue Y) const override;

/// Given an intrinsic, checks if on the target the intrinsic will need to map
/// to a MemIntrinsicNode (touches memory). If this is the case, it returns
Expand Down
Loading