Skip to content

Commit 33adfc8

Browse files
committed
[WIP][DAG] Add legalization handling for AVGCEIL/AVGFLOOR nodes
Still WIP, but I wanted to give other teams some visibility. Always match AVG patterns pre-legalization, and use TargetLowering::expandAVG to expand them again during legalization. I've removed the X86 custom AVGCEILU pattern detection and replaced it with combines that try to convert other AVG nodes to AVGCEILU.
1 parent ce67fcf commit 33adfc8

21 files changed

+3550
-9991
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5332,6 +5332,11 @@ class TargetLowering : public TargetLoweringBase {
53325332
/// \returns The expansion result or SDValue() if it fails.
53335333
SDValue expandABD(SDNode *N, SelectionDAG &DAG) const;
53345334

5335+
/// Expand vector/scalar AVGCEILS/AVGCEILU/AVGFLOORS/AVGFLOORU nodes.
5336+
/// \param N Node to expand
5337+
/// \returns The expansion result or SDValue() if it fails.
5338+
SDValue expandAVG(SDNode *N, SelectionDAG &DAG) const;
5339+
53355340
/// Expand BSWAP nodes. Expands scalar/vector BSWAP nodes with i16/i32/i64
53365341
/// scalar types. Returns SDValue() if expand fails.
53375342
/// \param N Node to expand

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2578,13 +2578,13 @@ SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
25782578
EVT VT = N0.getValueType();
25792579
SDValue A, B;
25802580

2581-
if (hasOperation(ISD::AVGCEILU, VT) &&
2581+
if ((!LegalOperations || hasOperation(ISD::AVGCEILU, VT)) &&
25822582
sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
25832583
m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
25842584
m_SpecificInt(1))))) {
25852585
return DAG.getNode(ISD::AVGCEILU, DL, VT, A, B);
25862586
}
2587-
if (hasOperation(ISD::AVGCEILS, VT) &&
2587+
if ((!LegalOperations || hasOperation(ISD::AVGCEILS, VT)) &&
25882588
sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
25892589
m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
25902590
m_SpecificInt(1))))) {
@@ -2950,13 +2950,13 @@ SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
29502950
EVT VT = N0.getValueType();
29512951
SDValue A, B;
29522952

2953-
if (hasOperation(ISD::AVGFLOORU, VT) &&
2953+
if ((!LegalOperations || hasOperation(ISD::AVGFLOORU, VT)) &&
29542954
sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
29552955
m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
29562956
m_SpecificInt(1))))) {
29572957
return DAG.getNode(ISD::AVGFLOORU, DL, VT, A, B);
29582958
}
2959-
if (hasOperation(ISD::AVGFLOORS, VT) &&
2959+
if ((!LegalOperations || hasOperation(ISD::AVGFLOORS, VT)) &&
29602960
sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
29612961
m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
29622962
m_SpecificInt(1))))) {
@@ -5234,6 +5234,21 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
52345234
if (N0 == N1 && Level >= AfterLegalizeTypes)
52355235
return N0;
52365236

5237+
// Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
5238+
// Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
5239+
// Check if avgflooru isn't legal/custom but avgceilu is.
5240+
if (Opcode == ISD::AVGFLOORU && !hasOperation(ISD::AVGFLOORU, VT) &&
5241+
(!LegalOperations || hasOperation(ISD::AVGCEILU, VT))) {
5242+
if (DAG.isKnownNeverZero(N0))
5243+
return DAG.getNode(
5244+
ISD::AVGCEILU, DL, VT, N1,
5245+
DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
5246+
if (DAG.isKnownNeverZero(N1))
5247+
return DAG.getNode(
5248+
ISD::AVGCEILU, DL, VT, N0,
5249+
DAG.getNode(ISD::ADD, DL, VT, N1, DAG.getAllOnesConstant(DL, VT)));
5250+
}
5251+
52375252
// TODO If we use avg for scalars anywhere, we can add (avgfl x, 0) -> x >> 1
52385253

52395254
return SDValue();

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3047,6 +3047,13 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
30473047
if ((Tmp1 = TLI.expandABD(Node, DAG)))
30483048
Results.push_back(Tmp1);
30493049
break;
3050+
case ISD::AVGCEILS:
3051+
case ISD::AVGCEILU:
3052+
case ISD::AVGFLOORS:
3053+
case ISD::AVGFLOORU:
3054+
if ((Tmp1 = TLI.expandAVG(Node, DAG)))
3055+
Results.push_back(Tmp1);
3056+
break;
30503057
case ISD::CTPOP:
30513058
if ((Tmp1 = TLI.expandCTPOP(Node, DAG)))
30523059
Results.push_back(Tmp1);

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,13 +188,17 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
188188
case ISD::VP_SUB:
189189
case ISD::VP_MUL: Res = PromoteIntRes_SimpleIntBinOp(N); break;
190190

191+
case ISD::AVGCEILS:
192+
case ISD::AVGFLOORS:
191193
case ISD::VP_SMIN:
192194
case ISD::VP_SMAX:
193195
case ISD::SDIV:
194196
case ISD::SREM:
195197
case ISD::VP_SDIV:
196198
case ISD::VP_SREM: Res = PromoteIntRes_SExtIntBinOp(N); break;
197199

200+
case ISD::AVGCEILU:
201+
case ISD::AVGFLOORU:
198202
case ISD::VP_UMIN:
199203
case ISD::VP_UMAX:
200204
case ISD::UDIV:
@@ -2775,6 +2779,11 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
27752779
case ISD::SSHLSAT:
27762780
case ISD::USHLSAT: ExpandIntRes_SHLSAT(N, Lo, Hi); break;
27772781

2782+
case ISD::AVGCEILS:
2783+
case ISD::AVGCEILU:
2784+
case ISD::AVGFLOORS:
2785+
case ISD::AVGFLOORU: ExpandIntRes_AVG(N, Lo, Hi); break;
2786+
27782787
case ISD::SMULFIX:
27792788
case ISD::SMULFIXSAT:
27802789
case ISD::UMULFIX:
@@ -4077,6 +4086,11 @@ void DAGTypeLegalizer::ExpandIntRes_READCOUNTER(SDNode *N, SDValue &Lo,
40774086
ReplaceValueWith(SDValue(N, 1), R.getValue(2));
40784087
}
40794088

4089+
/// Expand the integer result of an AVGCEILS/AVGCEILU/AVGFLOORS/AVGFLOORU node.
/// The node is lowered with TLI.expandAVG and the lowered value is then passed
/// to SplitInteger to populate \p Lo and \p Hi.
void DAGTypeLegalizer::ExpandIntRes_AVG(SDNode *N, SDValue &Lo, SDValue &Hi) {
4090+
// Lower the AVG node via the shared TargetLowering expansion.
SDValue Result = TLI.expandAVG(N, DAG);
4091+
// Split the lowered value into the Lo/Hi outputs.
SplitInteger(Result, Lo, Hi);
4092+
}
4093+
40804094
void DAGTypeLegalizer::ExpandIntRes_ADDSUBSAT(SDNode *N, SDValue &Lo,
40814095
SDValue &Hi) {
40824096
SDValue Result = TLI.expandAddSubSat(N, DAG);

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
460460
void ExpandIntRes_SADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
461461
void ExpandIntRes_UADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
462462
void ExpandIntRes_XMULO (SDNode *N, SDValue &Lo, SDValue &Hi);
463+
void ExpandIntRes_AVG (SDNode *N, SDValue &Lo, SDValue &Hi);
463464
void ExpandIntRes_ADDSUBSAT (SDNode *N, SDValue &Lo, SDValue &Hi);
464465
void ExpandIntRes_SHLSAT (SDNode *N, SDValue &Lo, SDValue &Hi);
465466
void ExpandIntRes_MULFIX (SDNode *N, SDValue &Lo, SDValue &Hi);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,10 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
369369
case ISD::ABS:
370370
case ISD::ABDS:
371371
case ISD::ABDU:
372+
case ISD::AVGCEILS:
373+
case ISD::AVGCEILU:
374+
case ISD::AVGFLOORS:
375+
case ISD::AVGFLOORU:
372376
case ISD::BSWAP:
373377
case ISD::BITREVERSE:
374378
case ISD::CTLZ:
@@ -916,6 +920,15 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
916920
return;
917921
}
918922
break;
923+
case ISD::AVGCEILS:
924+
case ISD::AVGCEILU:
925+
case ISD::AVGFLOORS:
926+
case ISD::AVGFLOORU:
927+
if (SDValue Expanded = TLI.expandAVG(Node, DAG)) {
928+
Results.push_back(Expanded);
929+
return;
930+
}
931+
break;
919932
case ISD::BITREVERSE:
920933
ExpandBITREVERSE(Node, Results);
921934
return;

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
125125
break;
126126
case ISD::ADD:
127127
case ISD::AND:
128+
case ISD::AVGCEILS:
129+
case ISD::AVGCEILU:
130+
case ISD::AVGFLOORS:
131+
case ISD::AVGFLOORU:
128132
case ISD::FADD:
129133
case ISD::FCOPYSIGN:
130134
case ISD::FDIV:
@@ -1171,7 +1175,12 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
11711175
case ISD::MUL: case ISD::VP_MUL:
11721176
case ISD::MULHS:
11731177
case ISD::MULHU:
1174-
case ISD::FADD: case ISD::VP_FADD:
1178+
case ISD::AVGCEILS:
1179+
case ISD::AVGCEILU:
1180+
case ISD::AVGFLOORS:
1181+
case ISD::AVGFLOORU:
1182+
case ISD::FADD:
1183+
case ISD::VP_FADD:
11751184
case ISD::FSUB: case ISD::VP_FSUB:
11761185
case ISD::FMUL: case ISD::VP_FMUL:
11771186
case ISD::FMINNUM: case ISD::VP_FMINNUM:

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4560,8 +4560,15 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
45604560
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
45614561
// SRA X, C -> adds C sign bits.
45624562
if (const APInt *ShAmt =
4563-
getValidMinimumShiftAmountConstant(Op, DemandedElts))
4563+
getValidMinimumShiftAmountConstant(Op, DemandedElts)) {
45644564
Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
4565+
} else {
4566+
KnownBits KnownAmt =
4567+
computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
4568+
if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(VTBits))
4569+
Tmp = std::min<uint64_t>(Tmp + KnownAmt.getConstant().getZExtValue(),
4570+
VTBits);
4571+
}
45654572
return Tmp;
45664573
case ISD::SHL:
45674574
if (const APInt *ShAmt =
@@ -4752,6 +4759,13 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
47524759
(VTBits - SignBitsOp0 + 1) + (VTBits - SignBitsOp1 + 1);
47534760
return OutValidBits > VTBits ? 1 : VTBits - OutValidBits + 1;
47544761
}
4762+
case ISD::AVGCEILS:
4763+
case ISD::AVGFLOORS:
4764+
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4765+
if (Tmp == 1)
4766+
break; // Early out.
4767+
Tmp2 = ComputeNumSignBits(Op.getOperand(1), DemandedElts, Depth + 1);
4768+
return std::min(Tmp, Tmp2);
47554769
case ISD::SREM:
47564770
// The sign bit is the LHS's sign bit, except when the result of the
47574771
// remainder is zero. The magnitude of the result should be less than or

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -947,11 +947,11 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
947947

948948
// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
949949
// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
950-
static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
950+
static SDValue combineShiftToAVG(SDValue Op,
951+
TargetLowering::TargetLoweringOpt &TLO,
951952
const TargetLowering &TLI,
952953
const APInt &DemandedBits,
953-
const APInt &DemandedElts,
954-
unsigned Depth) {
954+
const APInt &DemandedElts, unsigned Depth) {
955955
assert((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) &&
956956
"SRL or SRA node is required here!");
957957
// Is the right shift using an immediate value of 1?
@@ -1002,6 +1002,7 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
10021002
// If the shift is unsigned (srl):
10031003
// - Needs >= 1 zero bit for both operands.
10041004
// - Needs 1 demanded bit zero and >= 2 sign bits.
1005+
SelectionDAG &DAG = TLO.DAG;
10051006
unsigned ShiftOpc = Op.getOpcode();
10061007
bool IsSigned = false;
10071008
unsigned KnownBits;
@@ -1057,10 +1058,10 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
10571058
EVT NVT = EVT::getIntegerVT(*DAG.getContext(), llvm::bit_ceil(MinWidth));
10581059
if (VT.isVector())
10591060
NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
1060-
if (!TLI.isOperationLegalOrCustom(AVGOpc, NVT)) {
1061+
if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, NVT)) {
10611062
// If we could not transform, and (both) adds are nuw/nsw, we can use the
10621063
// larger type size to do the transform.
1063-
if (!TLI.isOperationLegalOrCustom(AVGOpc, VT))
1064+
if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, VT))
10641065
return SDValue();
10651066
if (DAG.willNotOverflowAdd(IsSigned, Add.getOperand(0),
10661067
Add.getOperand(1)) &&
@@ -1908,11 +1909,6 @@ bool TargetLowering::SimplifyDemandedBits(
19081909
SDValue Op1 = Op.getOperand(1);
19091910
EVT ShiftVT = Op1.getValueType();
19101911

1911-
// Try to match AVG patterns.
1912-
if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
1913-
DemandedElts, Depth + 1))
1914-
return TLO.CombineTo(Op, AVG);
1915-
19161912
KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
19171913
if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
19181914
unsigned ShAmt = KnownSA.getConstant().getZExtValue();
@@ -1994,6 +1990,12 @@ bool TargetLowering::SimplifyDemandedBits(
19941990
// shift amounts.
19951991
Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
19961992
}
1993+
1994+
// Try to match AVG patterns (after shift simplification).
1995+
if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
1996+
DemandedElts, Depth + 1))
1997+
return TLO.CombineTo(Op, AVG);
1998+
19971999
break;
19982000
}
19992001
case ISD::SRA: {
@@ -2015,11 +2017,6 @@ bool TargetLowering::SimplifyDemandedBits(
20152017
if (DemandedBits.isOne())
20162018
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
20172019

2018-
// Try to match AVG patterns.
2019-
if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
2020-
DemandedElts, Depth + 1))
2021-
return TLO.CombineTo(Op, AVG);
2022-
20232020
KnownBits KnownSA = TLO.DAG.computeKnownBits(Op1, DemandedElts, Depth + 1);
20242021
if (KnownSA.isConstant() && KnownSA.getConstant().ult(BitWidth)) {
20252022
unsigned ShAmt = KnownSA.getConstant().getZExtValue();
@@ -2106,6 +2103,12 @@ bool TargetLowering::SimplifyDemandedBits(
21062103
}
21072104
}
21082105
}
2106+
2107+
// Try to match AVG patterns (after shift simplification).
2108+
if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
2109+
DemandedElts, Depth + 1))
2110+
return TLO.CombineTo(Op, AVG);
2111+
21092112
break;
21102113
}
21112114
case ISD::FSHL:
@@ -9203,6 +9206,49 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
92039206
DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
92049207
}
92059208

9209+
/// Expand an AVGCEILS/AVGCEILU/AVGFLOORS/AVGFLOORU node into simpler nodes.
/// If both operands are known to have spare high bits, the result is built as
/// add (+1 for the ceil variants) followed by a shift right by one; otherwise
/// it is built from and/or, xor, a shift right by one, and add/sub.
/// \param N Node to expand
/// \returns The expansion result.
SDValue TargetLowering::expandAVG(SDNode *N, SelectionDAG &DAG) const {
9210+
SDLoc dl(N);
9211+
EVT VT = N->getValueType(0);
9212+
SDValue LHS = N->getOperand(0);
9213+
SDValue RHS = N->getOperand(1);
9214+
9215+
// Decode the opcode: floor vs. ceil rounding and signed vs. unsigned.
unsigned Opc = N->getOpcode();
9216+
bool IsFloor = Opc == ISD::AVGFLOORS || Opc == ISD::AVGFLOORU;
9217+
bool IsSigned = Opc == ISD::AVGCEILS || Opc == ISD::AVGFLOORS;
9218+
// Signed averages shift with SRA, unsigned averages with SRL.
unsigned ShiftOpc = IsSigned ? ISD::SRA : ISD::SRL;
9219+
assert((Opc == ISD::AVGFLOORS || Opc == ISD::AVGCEILS ||
9220+
Opc == ISD::AVGFLOORU || Opc == ISD::AVGCEILU) &&
9221+
"Unknown AVG node");
9222+
9223+
// If the operands are already extended, we can add+shift.
// "Extended" here means: for signed nodes both operands have at least two
// sign bits; for unsigned nodes both operands have at least one known
// leading zero bit.
9224+
bool IsExt =
9225+
(IsSigned && DAG.ComputeNumSignBits(LHS) >= 2 &&
9226+
DAG.ComputeNumSignBits(RHS) >= 2) ||
9227+
(!IsSigned && DAG.computeKnownBits(LHS).countMinLeadingZeros() >= 1 &&
9228+
DAG.computeKnownBits(RHS).countMinLeadingZeros() >= 1);
9229+
if (IsExt) {
9230+
SDValue Sum = DAG.getNode(ISD::ADD, dl, VT, LHS, RHS);
9231+
// The ceil variants add 1 before the shift.
if (!IsFloor)
9232+
Sum = DAG.getNode(ISD::ADD, dl, VT, Sum, DAG.getConstant(1, dl, VT));
9233+
return DAG.getNode(ShiftOpc, dl, VT, Sum,
9234+
DAG.getShiftAmountConstant(1, VT, dl));
9235+
}
9236+
9237+
// General case: build the result without forming the full-width sum.
// avgceils(lhs, rhs) -> sub(or(lhs,rhs),ashr(xor(lhs,rhs),1))
9238+
// avgceilu(lhs, rhs) -> sub(or(lhs,rhs),lshr(xor(lhs,rhs),1))
9239+
// avgfloors(lhs, rhs) -> add(and(lhs,rhs),ashr(xor(lhs,rhs),1))
9240+
// avgflooru(lhs, rhs) -> add(and(lhs,rhs),lshr(xor(lhs,rhs),1))
9241+
unsigned SumOpc = IsFloor ? ISD::ADD : ISD::SUB;
9242+
unsigned SignOpc = IsFloor ? ISD::AND : ISD::OR;
9243+
// NOTE(review): each operand is frozen before being used by both the
// and/or node and the xor node -- presumably so both uses observe the same
// value if an operand is undef/poison; confirm.
LHS = DAG.getFreeze(LHS);
9244+
RHS = DAG.getFreeze(RHS);
9245+
SDValue Sign = DAG.getNode(SignOpc, dl, VT, LHS, RHS);
9246+
SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, LHS, RHS);
9247+
SDValue Shift =
9248+
DAG.getNode(ShiftOpc, dl, VT, Xor, DAG.getShiftAmountConstant(1, VT, dl));
9249+
return DAG.getNode(SumOpc, dl, VT, Sign, Shift);
9250+
}
9251+
92069252
SDValue TargetLowering::expandBSWAP(SDNode *N, SelectionDAG &DAG) const {
92079253
SDLoc dl(N);
92089254
EVT VT = N->getValueType(0);

0 commit comments

Comments
 (0)