Skip to content

Commit ea2ee5d

Browse files
authored
[DAG] Add legalization handling for AVGCEIL/AVGFLOOR nodes (#92096)
Always match AVG patterns pre-legalization, and use TargetLowering::expandAVG to expand them again during legalization. I've removed the X86 custom AVGCEILU pattern detection and replaced it with combines that try to convert other AVG nodes to AVGCEILU.
1 parent ec16f44 commit ea2ee5d

19 files changed

+3519
-9990
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5345,6 +5345,11 @@ class TargetLowering : public TargetLoweringBase {
53455345
/// \returns The expansion result or SDValue() if it fails.
53465346
SDValue expandABD(SDNode *N, SelectionDAG &DAG) const;
53475347

5348+
/// Expand vector/scalar AVGCEILS/AVGCEILU/AVGFLOORS/AVGFLOORU nodes.
5349+
/// \param N Node to expand
5350+
/// \returns The expansion result or SDValue() if it fails.
5351+
SDValue expandAVG(SDNode *N, SelectionDAG &DAG) const;
5352+
53485353
/// Expand BSWAP nodes. Expands scalar/vector BSWAP nodes with i16/i32/i64
53495354
/// scalar types. Returns SDValue() if expand fails.
53505355
/// \param N Node to expand

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2575,13 +2575,13 @@ SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
25752575
EVT VT = N0.getValueType();
25762576
SDValue A, B;
25772577

2578-
if (hasOperation(ISD::AVGCEILU, VT) &&
2578+
if ((!LegalOperations || hasOperation(ISD::AVGCEILU, VT)) &&
25792579
sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
25802580
m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
25812581
m_SpecificInt(1))))) {
25822582
return DAG.getNode(ISD::AVGCEILU, DL, VT, A, B);
25832583
}
2584-
if (hasOperation(ISD::AVGCEILS, VT) &&
2584+
if ((!LegalOperations || hasOperation(ISD::AVGCEILS, VT)) &&
25852585
sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
25862586
m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
25872587
m_SpecificInt(1))))) {
@@ -2947,13 +2947,13 @@ SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
29472947
EVT VT = N0.getValueType();
29482948
SDValue A, B;
29492949

2950-
if (hasOperation(ISD::AVGFLOORU, VT) &&
2950+
if ((!LegalOperations || hasOperation(ISD::AVGFLOORU, VT)) &&
29512951
sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
29522952
m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
29532953
m_SpecificInt(1))))) {
29542954
return DAG.getNode(ISD::AVGFLOORU, DL, VT, A, B);
29552955
}
2956-
if (hasOperation(ISD::AVGFLOORS, VT) &&
2956+
if ((!LegalOperations || hasOperation(ISD::AVGFLOORS, VT)) &&
29572957
sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
29582958
m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
29592959
m_SpecificInt(1))))) {
@@ -5253,6 +5253,22 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
52535253
SDValue AvgCeilU = DAG.getNode(ISD::AVGCEILU, DL, A.getValueType(), A, B);
52545254
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgCeilU);
52555255
}
5256+
5257+
// Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
5258+
// Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
5259+
// Check if avgflooru isn't legal/custom but avgceilu is.
5260+
if (Opcode == ISD::AVGFLOORU && !hasOperation(ISD::AVGFLOORU, VT) &&
5261+
(!LegalOperations || hasOperation(ISD::AVGCEILU, VT))) {
5262+
if (DAG.isKnownNeverZero(N1))
5263+
return DAG.getNode(
5264+
ISD::AVGCEILU, DL, VT, N0,
5265+
DAG.getNode(ISD::ADD, DL, VT, N1, DAG.getAllOnesConstant(DL, VT)));
5266+
if (DAG.isKnownNeverZero(N0))
5267+
return DAG.getNode(
5268+
ISD::AVGCEILU, DL, VT, N1,
5269+
DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
5270+
}
5271+
52565272
return SDValue();
52575273
}
52585274

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3059,6 +3059,13 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
30593059
if ((Tmp1 = TLI.expandABD(Node, DAG)))
30603060
Results.push_back(Tmp1);
30613061
break;
3062+
case ISD::AVGCEILS:
3063+
case ISD::AVGCEILU:
3064+
case ISD::AVGFLOORS:
3065+
case ISD::AVGFLOORU:
3066+
if ((Tmp1 = TLI.expandAVG(Node, DAG)))
3067+
Results.push_back(Tmp1);
3068+
break;
30623069
case ISD::CTPOP:
30633070
if ((Tmp1 = TLI.expandCTPOP(Node, DAG)))
30643071
Results.push_back(Tmp1);

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,13 +188,17 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
188188
case ISD::VP_SUB:
189189
case ISD::VP_MUL: Res = PromoteIntRes_SimpleIntBinOp(N); break;
190190

191+
case ISD::AVGCEILS:
192+
case ISD::AVGFLOORS:
191193
case ISD::VP_SMIN:
192194
case ISD::VP_SMAX:
193195
case ISD::SDIV:
194196
case ISD::SREM:
195197
case ISD::VP_SDIV:
196198
case ISD::VP_SREM: Res = PromoteIntRes_SExtIntBinOp(N); break;
197199

200+
case ISD::AVGCEILU:
201+
case ISD::AVGFLOORU:
198202
case ISD::VP_UMIN:
199203
case ISD::VP_UMAX:
200204
case ISD::UDIV:
@@ -2818,6 +2822,11 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
28182822
case ISD::SSHLSAT:
28192823
case ISD::USHLSAT: ExpandIntRes_SHLSAT(N, Lo, Hi); break;
28202824

2825+
case ISD::AVGCEILS:
2826+
case ISD::AVGCEILU:
2827+
case ISD::AVGFLOORS:
2828+
case ISD::AVGFLOORU: ExpandIntRes_AVG(N, Lo, Hi); break;
2829+
28212830
case ISD::SMULFIX:
28222831
case ISD::SMULFIXSAT:
28232832
case ISD::UMULFIX:
@@ -4120,6 +4129,11 @@ void DAGTypeLegalizer::ExpandIntRes_READCOUNTER(SDNode *N, SDValue &Lo,
41204129
ReplaceValueWith(SDValue(N, 1), R.getValue(2));
41214130
}
41224131

4132+
// Expands the integer result of an AVG node (AVGCEILS/AVGCEILU/AVGFLOORS/
// AVGFLOORU are the opcodes dispatched here): builds the replacement value
// with TLI.expandAVG and hands it to SplitInteger to fill in Lo and Hi.
void DAGTypeLegalizer::ExpandIntRes_AVG(SDNode *N, SDValue &Lo, SDValue &Hi) {
4133+
SDValue Result = TLI.expandAVG(N, DAG);
4134+
SplitInteger(Result, Lo, Hi); // NOTE(review): Result is not checked for an empty SDValue here, unlike the other expandAVG call sites - confirm expandAVG cannot fail.
4135+
}
4136+
41234137
void DAGTypeLegalizer::ExpandIntRes_ADDSUBSAT(SDNode *N, SDValue &Lo,
41244138
SDValue &Hi) {
41254139
SDValue Result = TLI.expandAddSubSat(N, DAG);

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
479479
void ExpandIntRes_SADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
480480
void ExpandIntRes_UADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
481481
void ExpandIntRes_XMULO (SDNode *N, SDValue &Lo, SDValue &Hi);
482+
void ExpandIntRes_AVG (SDNode *N, SDValue &Lo, SDValue &Hi);
482483
void ExpandIntRes_ADDSUBSAT (SDNode *N, SDValue &Lo, SDValue &Hi);
483484
void ExpandIntRes_SHLSAT (SDNode *N, SDValue &Lo, SDValue &Hi);
484485
void ExpandIntRes_MULFIX (SDNode *N, SDValue &Lo, SDValue &Hi);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,10 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
369369
case ISD::ABS:
370370
case ISD::ABDS:
371371
case ISD::ABDU:
372+
case ISD::AVGCEILS:
373+
case ISD::AVGCEILU:
374+
case ISD::AVGFLOORS:
375+
case ISD::AVGFLOORU:
372376
case ISD::BSWAP:
373377
case ISD::BITREVERSE:
374378
case ISD::CTLZ:
@@ -918,6 +922,15 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
918922
return;
919923
}
920924
break;
925+
case ISD::AVGCEILS:
926+
case ISD::AVGCEILU:
927+
case ISD::AVGFLOORS:
928+
case ISD::AVGFLOORU:
929+
if (SDValue Expanded = TLI.expandAVG(Node, DAG)) {
930+
Results.push_back(Expanded);
931+
return;
932+
}
933+
break;
921934
case ISD::BITREVERSE:
922935
ExpandBITREVERSE(Node, Results);
923936
return;

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
126126
break;
127127
case ISD::ADD:
128128
case ISD::AND:
129+
case ISD::AVGCEILS:
130+
case ISD::AVGCEILU:
131+
case ISD::AVGFLOORS:
132+
case ISD::AVGFLOORU:
129133
case ISD::FADD:
130134
case ISD::FCOPYSIGN:
131135
case ISD::FDIV:
@@ -1173,6 +1177,10 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
11731177
case ISD::MUL: case ISD::VP_MUL:
11741178
case ISD::MULHS:
11751179
case ISD::MULHU:
1180+
case ISD::AVGCEILS:
1181+
case ISD::AVGCEILU:
1182+
case ISD::AVGFLOORS:
1183+
case ISD::AVGFLOORU:
11761184
case ISD::FADD: case ISD::VP_FADD:
11771185
case ISD::FSUB: case ISD::VP_FSUB:
11781186
case ISD::FMUL: case ISD::VP_FMUL:

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -951,11 +951,11 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
951951

952952
// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
953953
// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
954-
static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
954+
static SDValue combineShiftToAVG(SDValue Op,
955+
TargetLowering::TargetLoweringOpt &TLO,
955956
const TargetLowering &TLI,
956957
const APInt &DemandedBits,
957-
const APInt &DemandedElts,
958-
unsigned Depth) {
958+
const APInt &DemandedElts, unsigned Depth) {
959959
assert((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) &&
960960
"SRL or SRA node is required here!");
961961
// Is the right shift using an immediate value of 1?
@@ -1006,6 +1006,7 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
10061006
// If the shift is unsigned (srl):
10071007
// - Needs >= 1 zero bit for both operands.
10081008
// - Needs 1 demanded bit zero and >= 2 sign bits.
1009+
SelectionDAG &DAG = TLO.DAG;
10091010
unsigned ShiftOpc = Op.getOpcode();
10101011
bool IsSigned = false;
10111012
unsigned KnownBits;
@@ -1061,10 +1062,10 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
10611062
EVT NVT = EVT::getIntegerVT(*DAG.getContext(), llvm::bit_ceil(MinWidth));
10621063
if (VT.isVector())
10631064
NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
1064-
if (!TLI.isOperationLegalOrCustom(AVGOpc, NVT)) {
1065+
if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, NVT)) {
10651066
// If we could not transform, and (both) adds are nuw/nsw, we can use the
10661067
// larger type size to do the transform.
1067-
if (!TLI.isOperationLegalOrCustom(AVGOpc, VT))
1068+
if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, VT))
10681069
return SDValue();
10691070
if (DAG.willNotOverflowAdd(IsSigned, Add.getOperand(0),
10701071
Add.getOperand(1)) &&
@@ -2002,7 +2003,7 @@ bool TargetLowering::SimplifyDemandedBits(
20022003
}
20032004

20042005
// Try to match AVG patterns (after shift simplification).
2005-
if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
2006+
if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
20062007
DemandedElts, Depth + 1))
20072008
return TLO.CombineTo(Op, AVG);
20082009

@@ -2113,7 +2114,7 @@ bool TargetLowering::SimplifyDemandedBits(
21132114
}
21142115

21152116
// Try to match AVG patterns (after shift simplification).
2116-
if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
2117+
if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
21172118
DemandedElts, Depth + 1))
21182119
return TLO.CombineTo(Op, AVG);
21192120

@@ -9225,6 +9226,49 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
92259226
DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
92269227
}
92279228

9229+
// Expands an AVGCEILS/AVGCEILU/AVGFLOORS/AVGFLOORU node N into add/sub,
// and/or/xor and shift nodes of the same value type. Both paths below return
// the result of DAG.getNode.
SDValue TargetLowering::expandAVG(SDNode *N, SelectionDAG &DAG) const {
9230+
SDLoc dl(N);
9231+
EVT VT = N->getValueType(0);
9232+
SDValue LHS = N->getOperand(0);
9233+
SDValue RHS = N->getOperand(1);
9234+
9235+
unsigned Opc = N->getOpcode();
9236+
bool IsFloor = Opc == ISD::AVGFLOORS || Opc == ISD::AVGFLOORU; // false => ceil variant
9237+
bool IsSigned = Opc == ISD::AVGCEILS || Opc == ISD::AVGFLOORS; // false => unsigned variant
9238+
unsigned ShiftOpc = IsSigned ? ISD::SRA : ISD::SRL; // arithmetic shift for signed, logical for unsigned
9239+
assert((Opc == ISD::AVGFLOORS || Opc == ISD::AVGCEILS ||
9240+
Opc == ISD::AVGFLOORU || Opc == ISD::AVGCEILU) &&
9241+
"Unknown AVG node");
9242+
9243+
// Fast path: if both operands have a spare top bit (>= 2 sign bits when signed, >= 1 known leading zero when unsigned), add them (plus 1 for ceil) and shift right by 1.
9244+
bool IsExt =
9245+
(IsSigned && DAG.ComputeNumSignBits(LHS) >= 2 &&
9246+
DAG.ComputeNumSignBits(RHS) >= 2) ||
9247+
(!IsSigned && DAG.computeKnownBits(LHS).countMinLeadingZeros() >= 1 &&
9248+
DAG.computeKnownBits(RHS).countMinLeadingZeros() >= 1);
9249+
if (IsExt) {
9250+
SDValue Sum = DAG.getNode(ISD::ADD, dl, VT, LHS, RHS);
9251+
if (!IsFloor)
9252+
Sum = DAG.getNode(ISD::ADD, dl, VT, Sum, DAG.getConstant(1, dl, VT));
9253+
return DAG.getNode(ShiftOpc, dl, VT, Sum,
9254+
DAG.getShiftAmountConstant(1, VT, dl));
9255+
}
9256+
9257+
// avgceils(lhs, rhs) -> sub(or(lhs,rhs),ashr(xor(lhs,rhs),1))
9258+
// avgceilu(lhs, rhs) -> sub(or(lhs,rhs),lshr(xor(lhs,rhs),1))
9259+
// avgfloors(lhs, rhs) -> add(and(lhs,rhs),ashr(xor(lhs,rhs),1))
9260+
// avgflooru(lhs, rhs) -> add(and(lhs,rhs),lshr(xor(lhs,rhs),1))
9261+
unsigned SumOpc = IsFloor ? ISD::ADD : ISD::SUB;
9262+
unsigned SignOpc = IsFloor ? ISD::AND : ISD::OR;
9263+
LHS = DAG.getFreeze(LHS); // LHS/RHS each feed both Sign and Xor below; NOTE(review): presumably frozen so both uses observe the same value - confirm
9264+
RHS = DAG.getFreeze(RHS);
9265+
SDValue Sign = DAG.getNode(SignOpc, dl, VT, LHS, RHS);
9266+
SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, LHS, RHS);
9267+
SDValue Shift =
9268+
DAG.getNode(ShiftOpc, dl, VT, Xor, DAG.getShiftAmountConstant(1, VT, dl));
9269+
return DAG.getNode(SumOpc, dl, VT, Sign, Shift);
9270+
}
9271+
92289272
SDValue TargetLowering::expandBSWAP(SDNode *N, SelectionDAG &DAG) const {
92299273
SDLoc dl(N);
92309274
EVT VT = N->getValueType(0);

0 commit comments

Comments
 (0)