Skip to content

[AMDGPU] Promote uniform ops to I32 in DAGISel #106383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -3306,7 +3306,7 @@ class TargetLoweringBase {
/// Return true if it's profitable to narrow operations of type SrcVT to
/// DestVT. e.g. on x86, it's profitable to narrow from i32 to i8 but not from
/// i32 to i16.
virtual bool isNarrowingProfitable(EVT SrcVT, EVT DestVT) const {
virtual bool isNarrowingProfitable(SDNode *N, EVT SrcVT, EVT DestVT) const {
return false;
}

Expand Down
19 changes: 10 additions & 9 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7050,7 +7050,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) &&
TLI.isTruncateFree(VT, SrcVT) && TLI.isZExtFree(SrcVT, VT) &&
TLI.isTypeDesirableForOp(ISD::AND, SrcVT) &&
TLI.isNarrowingProfitable(VT, SrcVT))
TLI.isNarrowingProfitable(N, VT, SrcVT))
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT,
DAG.getNode(ISD::AND, DL, SrcVT, N0Op0,
DAG.getZExtOrTrunc(N1, DL, SrcVT)));
Expand Down Expand Up @@ -14622,7 +14622,7 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
// ShLeftAmt will indicate how much a narrowed load should be shifted left.
unsigned ShLeftAmt = 0;
if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
ExtVT == VT && TLI.isNarrowingProfitable(N, N0.getValueType(), VT)) {
if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
ShLeftAmt = N01->getZExtValue();
N0 = N0.getOperand(0);
Expand Down Expand Up @@ -15166,9 +15166,11 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
}

// trunc (select c, a, b) -> select c, (trunc a), (trunc b)
if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
TLI.isTruncateFree(SrcVT, VT)) {
if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse() &&
TLI.isTruncateFree(SrcVT, VT)) {
if (!LegalOperations ||
(TLI.isOperationLegal(ISD::SELECT, SrcVT) &&
TLI.isNarrowingProfitable(N0.getNode(), SrcVT, VT))) {
SDLoc SL(N0);
SDValue Cond = N0.getOperand(0);
SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
Expand Down Expand Up @@ -20109,10 +20111,9 @@ SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
// The narrowing should be profitable, the load/store operation should be
// legal (or custom) and the store size should be equal to the NewVT width.
while (NewBW < BitWidth &&
(NewVT.getStoreSizeInBits() != NewBW ||
!TLI.isOperationLegalOrCustom(Opc, NewVT) ||
!TLI.isNarrowingProfitable(VT, NewVT))) {
while (NewBW < BitWidth && (NewVT.getStoreSizeInBits() != NewBW ||
!TLI.isOperationLegalOrCustom(Opc, NewVT) ||
!TLI.isNarrowingProfitable(N, VT, NewVT))) {
NewBW = NextPowerOf2(NewBW);
NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
}
Expand Down
10 changes: 6 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1841,7 +1841,7 @@ bool TargetLowering::SimplifyDemandedBits(
for (unsigned SmallVTBits = llvm::bit_ceil(DemandedSize);
SmallVTBits < BitWidth; SmallVTBits = NextPowerOf2(SmallVTBits)) {
EVT SmallVT = EVT::getIntegerVT(*TLO.DAG.getContext(), SmallVTBits);
if (isNarrowingProfitable(VT, SmallVT) &&
if (isNarrowingProfitable(Op.getNode(), VT, SmallVT) &&
isTypeDesirableForOp(ISD::SHL, SmallVT) &&
isTruncateFree(VT, SmallVT) && isZExtFree(SmallVT, VT) &&
(!TLO.LegalOperations() || isOperationLegal(ISD::SHL, SmallVT))) {
Expand All @@ -1865,7 +1865,7 @@ bool TargetLowering::SimplifyDemandedBits(
if ((BitWidth % 2) == 0 && !VT.isVector() && ShAmt < HalfWidth &&
DemandedBits.countLeadingOnes() >= HalfWidth) {
EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), HalfWidth);
if (isNarrowingProfitable(VT, HalfVT) &&
if (isNarrowingProfitable(Op.getNode(), VT, HalfVT) &&
isTypeDesirableForOp(ISD::SHL, HalfVT) &&
isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) &&
(!TLO.LegalOperations() || isOperationLegal(ISD::SHL, HalfVT))) {
Expand Down Expand Up @@ -1984,7 +1984,7 @@ bool TargetLowering::SimplifyDemandedBits(
if ((BitWidth % 2) == 0 && !VT.isVector()) {
APInt HiBits = APInt::getHighBitsSet(BitWidth, BitWidth / 2);
EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), BitWidth / 2);
if (isNarrowingProfitable(VT, HalfVT) &&
if (isNarrowingProfitable(Op.getNode(), VT, HalfVT) &&
isTypeDesirableForOp(ISD::SRL, HalfVT) &&
isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) &&
(!TLO.LegalOperations() || isOperationLegal(ISD::SRL, HalfVT)) &&
Expand Down Expand Up @@ -4762,9 +4762,11 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
case ISD::SETULT:
case ISD::SETULE: {
EVT newVT = N0.getOperand(0).getValueType();
// FIXME: Should use isNarrowingProfitable.
if (DCI.isBeforeLegalizeOps() ||
(isOperationLegal(ISD::SETCC, newVT) &&
isCondCodeLegal(Cond, newVT.getSimpleVT()))) {
isCondCodeLegal(Cond, newVT.getSimpleVT()) &&
isTypeDesirableForOp(ISD::SETCC, newVT))) {
EVT NewSetCCVT = getSetCCResultType(Layout, *DAG.getContext(), newVT);
SDValue NewConst = DAG.getConstant(C1.trunc(InSize), dl, newVT);

Expand Down
35 changes: 33 additions & 2 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1022,14 +1022,45 @@ bool AMDGPUTargetLowering::isZExtFree(EVT Src, EVT Dest) const {
return Src == MVT::i32 && Dest == MVT::i64;
}

bool AMDGPUTargetLowering::isNarrowingProfitable(EVT SrcVT, EVT DestVT) const {
bool AMDGPUTargetLowering::isNarrowingProfitable(SDNode *N, EVT SrcVT,
EVT DestVT) const {
switch (N->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
case ISD::SHL:
case ISD::SRL:
case ISD::SRA:
case ISD::AND:
case ISD::OR:
case ISD::XOR:
case ISD::MUL:
case ISD::SETCC:
case ISD::SELECT:
if (Subtarget->has16BitInsts() &&
(DestVT.isVector() ? !Subtarget->hasVOP3PInsts() : true)) {
// Don't narrow back down to i16 if promoted to i32 already.
if (!N->isDivergent() && DestVT.isInteger() &&
DestVT.getScalarSizeInBits() > 1 &&
DestVT.getScalarSizeInBits() <= 16 &&
SrcVT.getScalarSizeInBits() > 16) {
return false;
}
}
return true;
default:
break;
}

// There aren't really 64-bit registers, but pairs of 32-bit ones and only a
// limited number of native 64-bit operations. Shrinking an operation to fit
// in a single 32-bit register should always be helpful. As currently used,
// this is much less general than the name suggests, and is only used in
// places trying to reduce the sizes of loads. Shrinking loads to < 32-bits is
// not profitable, and may actually be harmful.
return SrcVT.getSizeInBits() > 32 && DestVT.getSizeInBits() == 32;
if (isa<LoadSDNode>(N))
return SrcVT.getSizeInBits() > 32 && DestVT.getSizeInBits() == 32;

return true;
}

bool AMDGPUTargetLowering::isDesirableToCommuteWithShift(
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class AMDGPUTargetLowering : public TargetLowering {
NegatibleCost &Cost,
unsigned Depth) const override;

bool isNarrowingProfitable(EVT SrcVT, EVT DestVT) const override;
bool isNarrowingProfitable(SDNode *N, EVT SrcVT, EVT DestVT) const override;

bool isDesirableToCommuteWithShift(const SDNode *N,
CombineLevel Level) const override;
Expand Down
127 changes: 120 additions & 7 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
ISD::UADDO_CARRY,
ISD::SUB,
ISD::USUBO_CARRY,
ISD::MUL,
ISD::FADD,
ISD::FSUB,
ISD::FDIV,
Expand All @@ -909,9 +910,17 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
ISD::UMIN,
ISD::UMAX,
ISD::SETCC,
ISD::SELECT,
ISD::SMIN,
ISD::SMAX,
ISD::UMIN,
ISD::UMAX,
ISD::AND,
ISD::OR,
ISD::XOR,
ISD::SHL,
ISD::SRL,
ISD::SRA,
ISD::FSHR,
ISD::SINT_TO_FP,
ISD::UINT_TO_FP,
Expand Down Expand Up @@ -1948,13 +1957,6 @@ bool SITargetLowering::isTypeDesirableForOp(unsigned Op, EVT VT) const {
switch (Op) {
case ISD::LOAD:
case ISD::STORE:

// These operations are done with 32-bit instructions anyway.
case ISD::AND:
case ISD::OR:
case ISD::XOR:
case ISD::SELECT:
// TODO: Extensions?
return true;
default:
return false;
Expand Down Expand Up @@ -6733,6 +6735,93 @@ SDValue SITargetLowering::lowerFLDEXP(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(ISD::FLDEXP, DL, VT, Op.getOperand(0), TruncExp);
}

/// Pick the extension opcode (SIGN/ZERO/ANY_EXTEND) needed to widen the
/// operands of \p Op while preserving its result on the low bits.
static unsigned getExtOpcodeForPromotedOp(SDValue Op) {
  const unsigned Opc = Op->getOpcode();
  switch (Opc) {
  case ISD::SETCC: {
    // The sign of the comparison dictates the extension kind.
    ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
    return ISD::isSignedIntSetCC(CC) ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
  }
  case ISD::SMIN:
  case ISD::SMAX:
  case ISD::SRA:
    return ISD::SIGN_EXTEND;
  case ISD::UMIN:
  case ISD::UMAX:
  case ISD::SRL:
    return ISD::ZERO_EXTEND;
  case ISD::ADD:
  case ISD::SUB:
  case ISD::AND:
  case ISD::OR:
  case ISD::XOR:
  case ISD::SHL:
  case ISD::SELECT:
  case ISD::MUL:
    // operation result won't be influenced by garbage high bits.
    // TODO: are all of those cases correct, and are there more?
    return ISD::ANY_EXTEND;
  default:
    llvm_unreachable("unexpected opcode!");
  }
}

/// Promote a uniform (scalar-unit) sub-32-bit integer operation to i32.
///
/// Extends the operands with the extension kind required by the opcode,
/// performs the operation at i32 (or <N x i32>), and truncates back to the
/// original type. Returns an empty SDValue when promotion does not apply
/// (before legalize-ops, or when narrowing back to \p OpTy is profitable).
SDValue SITargetLowering::promoteUniformOpToI32(SDValue Op,
                                                DAGCombinerInfo &DCI) const {
  const unsigned Opc = Op.getOpcode();
  assert(Opc == ISD::ADD || Opc == ISD::SUB || Opc == ISD::SHL ||
         Opc == ISD::SRL || Opc == ISD::SRA || Opc == ISD::AND ||
         Opc == ISD::OR || Opc == ISD::XOR || Opc == ISD::MUL ||
         Opc == ISD::SETCC || Opc == ISD::SELECT || Opc == ISD::SMIN ||
         Opc == ISD::SMAX || Opc == ISD::UMIN || Opc == ISD::UMAX);

  // For setcc, the type being promoted is that of the compared operands,
  // not the (i1) result type.
  EVT OpTy = (Opc != ISD::SETCC) ? Op.getValueType()
                                 : Op->getOperand(0).getValueType();
  auto ExtTy = OpTy.changeElementType(MVT::i32);

  if (DCI.isBeforeLegalizeOps() ||
      isNarrowingProfitable(Op.getNode(), ExtTy, OpTy))
    return SDValue();

  auto &DAG = DCI.DAG;

  SDLoc DL(Op);
  SDValue LHS;
  SDValue RHS;
  if (Opc == ISD::SELECT) {
    // Promote the two selected values; the i1 condition stays as-is.
    LHS = Op->getOperand(1);
    RHS = Op->getOperand(2);
  } else {
    LHS = Op->getOperand(0);
    RHS = Op->getOperand(1);
  }

  const unsigned ExtOp = getExtOpcodeForPromotedOp(Op);
  LHS = DAG.getNode(ExtOp, DL, ExtTy, {LHS});

  // Special case: for shifts, the RHS always needs a zext.
  // (Fixed: the original tested SRA twice and omitted SHL, so SHL shift
  // amounts were any-extended instead of zero-extended.)
  if (Opc == ISD::SHL || Opc == ISD::SRL || Opc == ISD::SRA)
    RHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtTy, {RHS});
  else
    RHS = DAG.getNode(ExtOp, DL, ExtTy, {RHS});

  // setcc always return i1/i1 vec so no need to truncate after.
  if (Opc == ISD::SETCC) {
    ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
    return DAG.getSetCC(DL, Op.getValueType(), LHS, RHS, CC);
  }

  // For other ops, we extend the operation's return type as well so we need to
  // truncate back to the original type.
  SDValue NewVal;
  if (Opc == ISD::SELECT)
    NewVal = DAG.getNode(ISD::SELECT, DL, ExtTy, {Op->getOperand(0), LHS, RHS});
  else
    NewVal = DAG.getNode(Opc, DL, ExtTy, {LHS, RHS});

  return DAG.getZExtOrTrunc(NewVal, DL, OpTy);
}

// Custom lowering for vector multiplications and s_mul_u64.
SDValue SITargetLowering::lowerMUL(SDValue Op, SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
Expand Down Expand Up @@ -14687,8 +14776,32 @@ SDValue SITargetLowering::performClampCombine(SDNode *N,

SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
switch (N->getOpcode()) {
case ISD::ADD:
case ISD::SUB:
case ISD::SHL:
case ISD::SRL:
case ISD::SRA:
case ISD::AND:
case ISD::OR:
case ISD::XOR:
case ISD::MUL:
case ISD::SETCC:
case ISD::SELECT:
case ISD::SMIN:
case ISD::SMAX:
case ISD::UMIN:
case ISD::UMAX:
if (auto Res = promoteUniformOpToI32(SDValue(N, 0), DCI))
return Res;
break;
default:
break;
}

if (getTargetMachine().getOptLevel() == CodeGenOptLevel::None)
return SDValue();

switch (N->getOpcode()) {
case ISD::ADD:
return performAddCombine(N, DCI);
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/SIISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class SITargetLowering final : public AMDGPUTargetLowering {
SDValue lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFMINNUM_FMAXNUM(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFLDEXP(SDValue Op, SelectionDAG &DAG) const;
SDValue promoteUniformOpToI32(SDValue Op, DAGCombinerInfo &DCI) const;
SDValue lowerMUL(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerXMULO(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerXMUL_LOHI(SDValue Op, SelectionDAG &DAG) const;
Expand Down Expand Up @@ -463,7 +464,6 @@ class SITargetLowering final : public AMDGPUTargetLowering {
SDValue splitBinaryVectorOp(SDValue Op, SelectionDAG &DAG) const;
SDValue splitTernaryVectorOp(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;

void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const override;

Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34533,7 +34533,8 @@ bool X86TargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
return false;
}

bool X86TargetLowering::isNarrowingProfitable(EVT SrcVT, EVT DestVT) const {
bool X86TargetLowering::isNarrowingProfitable(SDNode *N, EVT SrcVT,
EVT DestVT) const {
// i16 instructions are longer (0x66 prefix) and potentially slower.
return !(SrcVT == MVT::i32 && DestVT == MVT::i16);
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/X86/X86ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,7 @@ namespace llvm {
/// Return true if it's profitable to narrow operations of type SrcVT to
/// DestVT. e.g. on x86, it's profitable to narrow from i32 to i8 but not
/// from i32 to i16.
bool isNarrowingProfitable(EVT SrcVT, EVT DestVT) const override;
bool isNarrowingProfitable(SDNode *N, EVT SrcVT, EVT DestVT) const override;

bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
EVT VT) const override;
Expand Down
Loading
Loading