Skip to content

Commit 758444c

Browse files
authored
[AMDGPU] Promote uniform ops to I32 in DAGISel (#106383)
Promote uniform binops, selects and setcc between 2 and 16 bits to 32 bits in DAGISel. Solves #64591.
1 parent 77af9d1 commit 758444c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+9357
-10722
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3306,7 +3306,7 @@ class TargetLoweringBase {
33063306
/// Return true if it's profitable to narrow operations of type SrcVT to
33073307
/// DestVT. e.g. on x86, it's profitable to narrow from i32 to i8 but not from
33083308
/// i32 to i16.
3309-
virtual bool isNarrowingProfitable(EVT SrcVT, EVT DestVT) const {
3309+
virtual bool isNarrowingProfitable(SDNode *N, EVT SrcVT, EVT DestVT) const {
33103310
return false;
33113311
}
33123312

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7132,7 +7132,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
71327132
if (N1C->getAPIntValue().countLeadingZeros() >= (BitWidth - SrcBitWidth) &&
71337133
TLI.isTruncateFree(VT, SrcVT) && TLI.isZExtFree(SrcVT, VT) &&
71347134
TLI.isTypeDesirableForOp(ISD::AND, SrcVT) &&
7135-
TLI.isNarrowingProfitable(VT, SrcVT))
7135+
TLI.isNarrowingProfitable(N, VT, SrcVT))
71367136
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT,
71377137
DAG.getNode(ISD::AND, DL, SrcVT, N0Op0,
71387138
DAG.getZExtOrTrunc(N1, DL, SrcVT)));
@@ -14704,7 +14704,7 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
1470414704
// ShLeftAmt will indicate how much a narrowed load should be shifted left.
1470514705
unsigned ShLeftAmt = 0;
1470614706
if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
14707-
ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
14707+
ExtVT == VT && TLI.isNarrowingProfitable(N, N0.getValueType(), VT)) {
1470814708
if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
1470914709
ShLeftAmt = N01->getZExtValue();
1471014710
N0 = N0.getOperand(0);
@@ -15264,9 +15264,11 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1526415264
}
1526515265

1526615266
// trunc (select c, a, b) -> select c, (trunc a), (trunc b)
15267-
if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse()) {
15268-
if ((!LegalOperations || TLI.isOperationLegal(ISD::SELECT, SrcVT)) &&
15269-
TLI.isTruncateFree(SrcVT, VT)) {
15267+
if (N0.getOpcode() == ISD::SELECT && N0.hasOneUse() &&
15268+
TLI.isTruncateFree(SrcVT, VT)) {
15269+
if (!LegalOperations ||
15270+
(TLI.isOperationLegal(ISD::SELECT, SrcVT) &&
15271+
TLI.isNarrowingProfitable(N0.getNode(), SrcVT, VT))) {
1527015272
SDLoc SL(N0);
1527115273
SDValue Cond = N0.getOperand(0);
1527215274
SDValue TruncOp0 = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
@@ -20207,10 +20209,9 @@ SDValue DAGCombiner::ReduceLoadOpStoreWidth(SDNode *N) {
2020720209
EVT NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
2020820210
// The narrowing should be profitable, the load/store operation should be
2020920211
// legal (or custom) and the store size should be equal to the NewVT width.
20210-
while (NewBW < BitWidth &&
20211-
(NewVT.getStoreSizeInBits() != NewBW ||
20212-
!TLI.isOperationLegalOrCustom(Opc, NewVT) ||
20213-
!TLI.isNarrowingProfitable(VT, NewVT))) {
20212+
while (NewBW < BitWidth && (NewVT.getStoreSizeInBits() != NewBW ||
20213+
!TLI.isOperationLegalOrCustom(Opc, NewVT) ||
20214+
!TLI.isNarrowingProfitable(N, VT, NewVT))) {
2021420215
NewBW = NextPowerOf2(NewBW);
2021520216
NewVT = EVT::getIntegerVT(*DAG.getContext(), NewBW);
2021620217
}

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1841,7 +1841,7 @@ bool TargetLowering::SimplifyDemandedBits(
18411841
for (unsigned SmallVTBits = llvm::bit_ceil(DemandedSize);
18421842
SmallVTBits < BitWidth; SmallVTBits = NextPowerOf2(SmallVTBits)) {
18431843
EVT SmallVT = EVT::getIntegerVT(*TLO.DAG.getContext(), SmallVTBits);
1844-
if (isNarrowingProfitable(VT, SmallVT) &&
1844+
if (isNarrowingProfitable(Op.getNode(), VT, SmallVT) &&
18451845
isTypeDesirableForOp(ISD::SHL, SmallVT) &&
18461846
isTruncateFree(VT, SmallVT) && isZExtFree(SmallVT, VT) &&
18471847
(!TLO.LegalOperations() || isOperationLegal(ISD::SHL, SmallVT))) {
@@ -1865,7 +1865,7 @@ bool TargetLowering::SimplifyDemandedBits(
18651865
if ((BitWidth % 2) == 0 && !VT.isVector() && ShAmt < HalfWidth &&
18661866
DemandedBits.countLeadingOnes() >= HalfWidth) {
18671867
EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), HalfWidth);
1868-
if (isNarrowingProfitable(VT, HalfVT) &&
1868+
if (isNarrowingProfitable(Op.getNode(), VT, HalfVT) &&
18691869
isTypeDesirableForOp(ISD::SHL, HalfVT) &&
18701870
isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) &&
18711871
(!TLO.LegalOperations() || isOperationLegal(ISD::SHL, HalfVT))) {
@@ -1984,7 +1984,7 @@ bool TargetLowering::SimplifyDemandedBits(
19841984
if ((BitWidth % 2) == 0 && !VT.isVector()) {
19851985
APInt HiBits = APInt::getHighBitsSet(BitWidth, BitWidth / 2);
19861986
EVT HalfVT = EVT::getIntegerVT(*TLO.DAG.getContext(), BitWidth / 2);
1987-
if (isNarrowingProfitable(VT, HalfVT) &&
1987+
if (isNarrowingProfitable(Op.getNode(), VT, HalfVT) &&
19881988
isTypeDesirableForOp(ISD::SRL, HalfVT) &&
19891989
isTruncateFree(VT, HalfVT) && isZExtFree(HalfVT, VT) &&
19901990
(!TLO.LegalOperations() || isOperationLegal(ISD::SRL, HalfVT)) &&
@@ -4762,9 +4762,11 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
47624762
case ISD::SETULT:
47634763
case ISD::SETULE: {
47644764
EVT newVT = N0.getOperand(0).getValueType();
4765+
// FIXME: Should use isNarrowingProfitable.
47654766
if (DCI.isBeforeLegalizeOps() ||
47664767
(isOperationLegal(ISD::SETCC, newVT) &&
4767-
isCondCodeLegal(Cond, newVT.getSimpleVT()))) {
4768+
isCondCodeLegal(Cond, newVT.getSimpleVT()) &&
4769+
isTypeDesirableForOp(ISD::SETCC, newVT))) {
47684770
EVT NewSetCCVT = getSetCCResultType(Layout, *DAG.getContext(), newVT);
47694771
SDValue NewConst = DAG.getConstant(C1.trunc(InSize), dl, newVT);
47704772

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,14 +1022,45 @@ bool AMDGPUTargetLowering::isZExtFree(EVT Src, EVT Dest) const {
10221022
return Src == MVT::i32 && Dest == MVT::i64;
10231023
}
10241024

1025-
bool AMDGPUTargetLowering::isNarrowingProfitable(EVT SrcVT, EVT DestVT) const {
1025+
bool AMDGPUTargetLowering::isNarrowingProfitable(SDNode *N, EVT SrcVT,
1026+
EVT DestVT) const {
1027+
switch (N->getOpcode()) {
1028+
case ISD::ADD:
1029+
case ISD::SUB:
1030+
case ISD::SHL:
1031+
case ISD::SRL:
1032+
case ISD::SRA:
1033+
case ISD::AND:
1034+
case ISD::OR:
1035+
case ISD::XOR:
1036+
case ISD::MUL:
1037+
case ISD::SETCC:
1038+
case ISD::SELECT:
1039+
if (Subtarget->has16BitInsts() &&
1040+
(DestVT.isVector() ? !Subtarget->hasVOP3PInsts() : true)) {
1041+
// Don't narrow back down to i16 if promoted to i32 already.
1042+
if (!N->isDivergent() && DestVT.isInteger() &&
1043+
DestVT.getScalarSizeInBits() > 1 &&
1044+
DestVT.getScalarSizeInBits() <= 16 &&
1045+
SrcVT.getScalarSizeInBits() > 16) {
1046+
return false;
1047+
}
1048+
}
1049+
return true;
1050+
default:
1051+
break;
1052+
}
1053+
10261054
// There aren't really 64-bit registers, but pairs of 32-bit ones and only a
10271055
// limited number of native 64-bit operations. Shrinking an operation to fit
10281056
// in a single 32-bit register should always be helpful. As currently used,
10291057
// this is much less general than the name suggests, and is only used in
10301058
// places trying to reduce the sizes of loads. Shrinking loads to < 32-bits is
10311059
// not profitable, and may actually be harmful.
1032-
return SrcVT.getSizeInBits() > 32 && DestVT.getSizeInBits() == 32;
1060+
if (isa<LoadSDNode>(N))
1061+
return SrcVT.getSizeInBits() > 32 && DestVT.getSizeInBits() == 32;
1062+
1063+
return true;
10331064
}
10341065

10351066
bool AMDGPUTargetLowering::isDesirableToCommuteWithShift(

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ class AMDGPUTargetLowering : public TargetLowering {
201201
NegatibleCost &Cost,
202202
unsigned Depth) const override;
203203

204-
bool isNarrowingProfitable(EVT SrcVT, EVT DestVT) const override;
204+
bool isNarrowingProfitable(SDNode *N, EVT SrcVT, EVT DestVT) const override;
205205

206206
bool isDesirableToCommuteWithShift(const SDNode *N,
207207
CombineLevel Level) const override;

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 120 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
894894
ISD::UADDO_CARRY,
895895
ISD::SUB,
896896
ISD::USUBO_CARRY,
897+
ISD::MUL,
897898
ISD::FADD,
898899
ISD::FSUB,
899900
ISD::FDIV,
@@ -909,9 +910,17 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
909910
ISD::UMIN,
910911
ISD::UMAX,
911912
ISD::SETCC,
913+
ISD::SELECT,
914+
ISD::SMIN,
915+
ISD::SMAX,
916+
ISD::UMIN,
917+
ISD::UMAX,
912918
ISD::AND,
913919
ISD::OR,
914920
ISD::XOR,
921+
ISD::SHL,
922+
ISD::SRL,
923+
ISD::SRA,
915924
ISD::FSHR,
916925
ISD::SINT_TO_FP,
917926
ISD::UINT_TO_FP,
@@ -1942,13 +1951,6 @@ bool SITargetLowering::isTypeDesirableForOp(unsigned Op, EVT VT) const {
19421951
switch (Op) {
19431952
case ISD::LOAD:
19441953
case ISD::STORE:
1945-
1946-
// These operations are done with 32-bit instructions anyway.
1947-
case ISD::AND:
1948-
case ISD::OR:
1949-
case ISD::XOR:
1950-
case ISD::SELECT:
1951-
// TODO: Extensions?
19521954
return true;
19531955
default:
19541956
return false;
@@ -6731,6 +6733,93 @@ SDValue SITargetLowering::lowerFLDEXP(SDValue Op, SelectionDAG &DAG) const {
67316733
return DAG.getNode(ISD::FLDEXP, DL, VT, Op.getOperand(0), TruncExp);
67326734
}
67336735

6736+
/// Select the extension opcode used to widen the operands of \p Op when the
/// operation is promoted to a wider integer type.
///
/// \returns ISD::SIGN_EXTEND, ISD::ZERO_EXTEND or ISD::ANY_EXTEND depending on
/// the opcode of \p Op (and, for SETCC, on its condition code). Any opcode not
/// listed here hits llvm_unreachable.
static unsigned getExtOpcodeForPromotedOp(SDValue Op) {
  const unsigned Opcode = Op->getOpcode();

  // Comparisons pick the extension from their predicate: signed integer
  // predicates sign-extend, every other predicate zero-extends.
  if (Opcode == ISD::SETCC) {
    const ISD::CondCode Pred = cast<CondCodeSDNode>(Op.getOperand(2))->get();
    if (ISD::isSignedIntSetCC(Pred))
      return ISD::SIGN_EXTEND;
    return ISD::ZERO_EXTEND;
  }

  switch (Opcode) {
  case ISD::SRL:
  case ISD::UMIN:
  case ISD::UMAX:
    return ISD::ZERO_EXTEND;
  case ISD::SRA:
  case ISD::SMIN:
  case ISD::SMAX:
    return ISD::SIGN_EXTEND;
  case ISD::ADD:
  case ISD::SUB:
  case ISD::MUL:
  case ISD::AND:
  case ISD::OR:
  case ISD::XOR:
  case ISD::SHL:
  case ISD::SELECT:
    // The operation result won't be influenced by garbage high bits.
    // TODO: are all of those cases correct, and are there more?
    return ISD::ANY_EXTEND;
  default:
    llvm_unreachable("unexpected opcode!");
  }
}
6765+
6766+
/// Rewrite a narrow integer operation as the same operation on i32 elements.
///
/// \p Op must be an ADD, SUB, SHL, SRL, SRA, AND, OR, XOR, MUL, SETCC, SELECT,
/// SMIN, SMAX, UMIN or UMAX node (asserted). The operands are extended to the
/// i32-element type using the extension chosen by getExtOpcodeForPromotedOp,
/// the operation is re-created in that type, and the result is converted back
/// to the original type (except for SETCC, whose result type is unchanged).
///
/// \param Op  The node to promote.
/// \param DCI Combiner state; provides the DAG and the legalization phase.
/// \returns The replacement value, or an empty SDValue when nothing is done:
/// either operations have not been legalized yet, or isNarrowingProfitable
/// reports that going from the i32-element type to the original type is
/// profitable (i.e. the operation should stay narrow).
SDValue SITargetLowering::promoteUniformOpToI32(SDValue Op,
                                                DAGCombinerInfo &DCI) const {
  const unsigned Opc = Op.getOpcode();
  assert(Opc == ISD::ADD || Opc == ISD::SUB || Opc == ISD::SHL ||
         Opc == ISD::SRL || Opc == ISD::SRA || Opc == ISD::AND ||
         Opc == ISD::OR || Opc == ISD::XOR || Opc == ISD::MUL ||
         Opc == ISD::SETCC || Opc == ISD::SELECT || Opc == ISD::SMIN ||
         Opc == ISD::SMAX || Opc == ISD::UMIN || Opc == ISD::UMAX);

  // For SETCC the type being promoted is the compared operand type, not the
  // result type.
  EVT OpTy = (Opc != ISD::SETCC) ? Op.getValueType()
                                 : Op->getOperand(0).getValueType();
  auto ExtTy = OpTy.changeElementType(MVT::i32);

  if (DCI.isBeforeLegalizeOps() ||
      isNarrowingProfitable(Op.getNode(), ExtTy, OpTy))
    return SDValue();

  auto &DAG = DCI.DAG;

  SDLoc DL(Op);
  SDValue LHS;
  SDValue RHS;
  if (Opc == ISD::SELECT) {
    // Operand 0 of a select is the condition; the values are operands 1 and 2.
    LHS = Op->getOperand(1);
    RHS = Op->getOperand(2);
  } else {
    LHS = Op->getOperand(0);
    RHS = Op->getOperand(1);
  }

  const unsigned ExtOp = getExtOpcodeForPromotedOp(Op);
  LHS = DAG.getNode(ExtOp, DL, ExtTy, {LHS});

  // Special case: for shifts, the RHS always needs a zext. This must include
  // SHL: its value operand is any-extended, and applying the same extension to
  // the shift amount would leave its high bits undefined.
  if (Opc == ISD::SHL || Opc == ISD::SRL || Opc == ISD::SRA)
    RHS = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtTy, {RHS});
  else
    RHS = DAG.getNode(ExtOp, DL, ExtTy, {RHS});

  // setcc always returns i1 (or a vector of i1), so there is no need to
  // truncate afterwards.
  if (Opc == ISD::SETCC) {
    ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
    return DAG.getSetCC(DL, Op.getValueType(), LHS, RHS, CC);
  }

  // For other ops, we extend the operation's return type as well so we need to
  // truncate back to the original type.
  SDValue NewVal;
  if (Opc == ISD::SELECT)
    NewVal = DAG.getNode(ISD::SELECT, DL, ExtTy, {Op->getOperand(0), LHS, RHS});
  else
    NewVal = DAG.getNode(Opc, DL, ExtTy, {LHS, RHS});

  return DAG.getZExtOrTrunc(NewVal, DL, OpTy);
}
6822+
67346823
// Custom lowering for vector multiplications and s_mul_u64.
67356824
SDValue SITargetLowering::lowerMUL(SDValue Op, SelectionDAG &DAG) const {
67366825
EVT VT = Op.getValueType();
@@ -14623,8 +14712,32 @@ SDValue SITargetLowering::performClampCombine(SDNode *N,
1462314712

1462414713
SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
1462514714
DAGCombinerInfo &DCI) const {
14715+
switch (N->getOpcode()) {
14716+
case ISD::ADD:
14717+
case ISD::SUB:
14718+
case ISD::SHL:
14719+
case ISD::SRL:
14720+
case ISD::SRA:
14721+
case ISD::AND:
14722+
case ISD::OR:
14723+
case ISD::XOR:
14724+
case ISD::MUL:
14725+
case ISD::SETCC:
14726+
case ISD::SELECT:
14727+
case ISD::SMIN:
14728+
case ISD::SMAX:
14729+
case ISD::UMIN:
14730+
case ISD::UMAX:
14731+
if (auto Res = promoteUniformOpToI32(SDValue(N, 0), DCI))
14732+
return Res;
14733+
break;
14734+
default:
14735+
break;
14736+
}
14737+
1462614738
if (getTargetMachine().getOptLevel() == CodeGenOptLevel::None)
1462714739
return SDValue();
14740+
1462814741
switch (N->getOpcode()) {
1462914742
case ISD::ADD:
1463014743
return performAddCombine(N, DCI);

llvm/lib/Target/AMDGPU/SIISelLowering.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class SITargetLowering final : public AMDGPUTargetLowering {
147147
SDValue lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
148148
SDValue lowerFMINNUM_FMAXNUM(SDValue Op, SelectionDAG &DAG) const;
149149
SDValue lowerFLDEXP(SDValue Op, SelectionDAG &DAG) const;
150+
SDValue promoteUniformOpToI32(SDValue Op, DAGCombinerInfo &DCI) const;
150151
SDValue lowerMUL(SDValue Op, SelectionDAG &DAG) const;
151152
SDValue lowerXMULO(SDValue Op, SelectionDAG &DAG) const;
152153
SDValue lowerXMUL_LOHI(SDValue Op, SelectionDAG &DAG) const;
@@ -462,7 +463,6 @@ class SITargetLowering final : public AMDGPUTargetLowering {
462463
SDValue splitBinaryVectorOp(SDValue Op, SelectionDAG &DAG) const;
463464
SDValue splitTernaryVectorOp(SDValue Op, SelectionDAG &DAG) const;
464465
SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;
465-
466466
void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
467467
SelectionDAG &DAG) const override;
468468

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34834,7 +34834,8 @@ bool X86TargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
3483434834
return false;
3483534835
}
3483634836

34837-
bool X86TargetLowering::isNarrowingProfitable(EVT SrcVT, EVT DestVT) const {
34837+
bool X86TargetLowering::isNarrowingProfitable(SDNode *N, EVT SrcVT,
34838+
EVT DestVT) const {
3483834839
// i16 instructions are longer (0x66 prefix) and potentially slower.
3483934840
return !(SrcVT == MVT::i32 && DestVT == MVT::i16);
3484034841
}

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1451,7 +1451,7 @@ namespace llvm {
14511451
/// Return true if it's profitable to narrow operations of type SrcVT to
14521452
/// DestVT. e.g. on x86, it's profitable to narrow from i32 to i8 but not
14531453
/// from i32 to i16.
1454-
bool isNarrowingProfitable(EVT SrcVT, EVT DestVT) const override;
1454+
bool isNarrowingProfitable(SDNode *N, EVT SrcVT, EVT DestVT) const override;
14551455

14561456
bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode,
14571457
EVT VT) const override;

0 commit comments

Comments
 (0)