Skip to content

Commit 2f9a4fb

Browse files
committed
[WIP][DAG] Add legalization handling for AVGCEIL/AVGFLOOR nodes
Still WIP, but I wanted to get some visibility to other teams. Always match AVG patterns pre-legalization, and use TargetLowering::expandAVG to expand again during legalization. I've removed the X86 custom AVGCEILU pattern detection and replaced with combines to try and convert other AVG nodes to AVGCEILU
1 parent cd5ee27 commit 2f9a4fb

21 files changed

+3574
-10016
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5332,6 +5332,11 @@ class TargetLowering : public TargetLoweringBase {
53325332
/// \returns The expansion result or SDValue() if it fails.
53335333
SDValue expandABD(SDNode *N, SelectionDAG &DAG) const;
53345334

5335+
/// Expand vector/scalar AVGCEILS/AVGCEILU/AVGFLOORS/AVGFLOORU nodes.
5336+
/// \param N Node to expand
5337+
/// \returns The expansion result or SDValue() if it fails.
5338+
SDValue expandAVG(SDNode *N, SelectionDAG &DAG) const;
5339+
53355340
/// Expand BSWAP nodes. Expands scalar/vector BSWAP nodes with i16/i32/i64
53365341
/// scalar types. Returns SDValue() if expand fails.
53375342
/// \param N Node to expand

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2578,13 +2578,13 @@ SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
25782578
EVT VT = N0.getValueType();
25792579
SDValue A, B;
25802580

2581-
if (hasOperation(ISD::AVGCEILU, VT) &&
2581+
if ((!LegalOperations || hasOperation(ISD::AVGCEILU, VT)) &&
25822582
sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
25832583
m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
25842584
m_SpecificInt(1))))) {
25852585
return DAG.getNode(ISD::AVGCEILU, DL, VT, A, B);
25862586
}
2587-
if (hasOperation(ISD::AVGCEILS, VT) &&
2587+
if ((!LegalOperations || hasOperation(ISD::AVGCEILS, VT)) &&
25882588
sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
25892589
m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
25902590
m_SpecificInt(1))))) {
@@ -2950,13 +2950,13 @@ SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
29502950
EVT VT = N0.getValueType();
29512951
SDValue A, B;
29522952

2953-
if (hasOperation(ISD::AVGFLOORU, VT) &&
2953+
if ((!LegalOperations || hasOperation(ISD::AVGFLOORU, VT)) &&
29542954
sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
29552955
m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
29562956
m_SpecificInt(1))))) {
29572957
return DAG.getNode(ISD::AVGFLOORU, DL, VT, A, B);
29582958
}
2959-
if (hasOperation(ISD::AVGFLOORS, VT) &&
2959+
if ((!LegalOperations || hasOperation(ISD::AVGFLOORS, VT)) &&
29602960
sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
29612961
m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
29622962
m_SpecificInt(1))))) {
@@ -5234,6 +5234,21 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
52345234
if (N0 == N1 && Level >= AfterLegalizeTypes)
52355235
return N0;
52365236

5237+
// Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
5238+
// Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
5239+
// Check if avgflooru isn't legal/custom but avgceilu is.
5240+
if (Opcode == ISD::AVGFLOORU && !hasOperation(ISD::AVGFLOORU, VT) &&
5241+
(!LegalOperations || hasOperation(ISD::AVGCEILU, VT))) {
5242+
if (DAG.isKnownNeverZero(N0))
5243+
return DAG.getNode(
5244+
ISD::AVGCEILU, DL, VT, N1,
5245+
DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
5246+
if (DAG.isKnownNeverZero(N1))
5247+
return DAG.getNode(
5248+
ISD::AVGCEILU, DL, VT, N0,
5249+
DAG.getNode(ISD::ADD, DL, VT, N1, DAG.getAllOnesConstant(DL, VT)));
5250+
}
5251+
52375252
// TODO If we use avg for scalars anywhere, we can add (avgfl x, 0) -> x >> 1
52385253

52395254
return SDValue();

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3047,6 +3047,13 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
30473047
if ((Tmp1 = TLI.expandABD(Node, DAG)))
30483048
Results.push_back(Tmp1);
30493049
break;
3050+
case ISD::AVGCEILS:
3051+
case ISD::AVGCEILU:
3052+
case ISD::AVGFLOORS:
3053+
case ISD::AVGFLOORU:
3054+
if ((Tmp1 = TLI.expandAVG(Node, DAG)))
3055+
Results.push_back(Tmp1);
3056+
break;
30503057
case ISD::CTPOP:
30513058
if ((Tmp1 = TLI.expandCTPOP(Node, DAG)))
30523059
Results.push_back(Tmp1);

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,13 +188,17 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
188188
case ISD::VP_SUB:
189189
case ISD::VP_MUL: Res = PromoteIntRes_SimpleIntBinOp(N); break;
190190

191+
case ISD::AVGCEILS:
192+
case ISD::AVGFLOORS:
191193
case ISD::VP_SMIN:
192194
case ISD::VP_SMAX:
193195
case ISD::SDIV:
194196
case ISD::SREM:
195197
case ISD::VP_SDIV:
196198
case ISD::VP_SREM: Res = PromoteIntRes_SExtIntBinOp(N); break;
197199

200+
case ISD::AVGCEILU:
201+
case ISD::AVGFLOORU:
198202
case ISD::VP_UMIN:
199203
case ISD::VP_UMAX:
200204
case ISD::UDIV:
@@ -2775,6 +2779,11 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
27752779
case ISD::SSHLSAT:
27762780
case ISD::USHLSAT: ExpandIntRes_SHLSAT(N, Lo, Hi); break;
27772781

2782+
case ISD::AVGCEILS:
2783+
case ISD::AVGCEILU:
2784+
case ISD::AVGFLOORS:
2785+
case ISD::AVGFLOORU: ExpandIntRes_AVG(N, Lo, Hi); break;
2786+
27782787
case ISD::SMULFIX:
27792788
case ISD::SMULFIXSAT:
27802789
case ISD::UMULFIX:
@@ -4077,6 +4086,11 @@ void DAGTypeLegalizer::ExpandIntRes_READCOUNTER(SDNode *N, SDValue &Lo,
40774086
ReplaceValueWith(SDValue(N, 1), R.getValue(2));
40784087
}
40794088

4089+
void DAGTypeLegalizer::ExpandIntRes_AVG(SDNode *N, SDValue &Lo, SDValue &Hi) {
4090+
SDValue Result = TLI.expandAVG(N, DAG);
4091+
SplitInteger(Result, Lo, Hi);
4092+
}
4093+
40804094
void DAGTypeLegalizer::ExpandIntRes_ADDSUBSAT(SDNode *N, SDValue &Lo,
40814095
SDValue &Hi) {
40824096
SDValue Result = TLI.expandAddSubSat(N, DAG);

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
460460
void ExpandIntRes_SADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
461461
void ExpandIntRes_UADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
462462
void ExpandIntRes_XMULO (SDNode *N, SDValue &Lo, SDValue &Hi);
463+
void ExpandIntRes_AVG (SDNode *N, SDValue &Lo, SDValue &Hi);
463464
void ExpandIntRes_ADDSUBSAT (SDNode *N, SDValue &Lo, SDValue &Hi);
464465
void ExpandIntRes_SHLSAT (SDNode *N, SDValue &Lo, SDValue &Hi);
465466
void ExpandIntRes_MULFIX (SDNode *N, SDValue &Lo, SDValue &Hi);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,10 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
369369
case ISD::ABS:
370370
case ISD::ABDS:
371371
case ISD::ABDU:
372+
case ISD::AVGCEILS:
373+
case ISD::AVGCEILU:
374+
case ISD::AVGFLOORS:
375+
case ISD::AVGFLOORU:
372376
case ISD::BSWAP:
373377
case ISD::BITREVERSE:
374378
case ISD::CTLZ:
@@ -916,6 +920,15 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
916920
return;
917921
}
918922
break;
923+
case ISD::AVGCEILS:
924+
case ISD::AVGCEILU:
925+
case ISD::AVGFLOORS:
926+
case ISD::AVGFLOORU:
927+
if (SDValue Expanded = TLI.expandAVG(Node, DAG)) {
928+
Results.push_back(Expanded);
929+
return;
930+
}
931+
break;
919932
case ISD::BITREVERSE:
920933
ExpandBITREVERSE(Node, Results);
921934
return;

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
125125
break;
126126
case ISD::ADD:
127127
case ISD::AND:
128+
case ISD::AVGCEILS:
129+
case ISD::AVGCEILU:
130+
case ISD::AVGFLOORS:
131+
case ISD::AVGFLOORU:
128132
case ISD::FADD:
129133
case ISD::FCOPYSIGN:
130134
case ISD::FDIV:
@@ -1171,7 +1175,12 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
11711175
case ISD::MUL: case ISD::VP_MUL:
11721176
case ISD::MULHS:
11731177
case ISD::MULHU:
1174-
case ISD::FADD: case ISD::VP_FADD:
1178+
case ISD::AVGCEILS:
1179+
case ISD::AVGCEILU:
1180+
case ISD::AVGFLOORS:
1181+
case ISD::AVGFLOORU:
1182+
case ISD::FADD:
1183+
case ISD::VP_FADD:
11751184
case ISD::FSUB: case ISD::VP_FSUB:
11761185
case ISD::FMUL: case ISD::VP_FMUL:
11771186
case ISD::FMINNUM: case ISD::VP_FMINNUM:

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4752,6 +4752,13 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
47524752
(VTBits - SignBitsOp0 + 1) + (VTBits - SignBitsOp1 + 1);
47534753
return OutValidBits > VTBits ? 1 : VTBits - OutValidBits + 1;
47544754
}
4755+
case ISD::AVGCEILS:
4756+
case ISD::AVGFLOORS:
4757+
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
4758+
if (Tmp == 1)
4759+
break; // Early out.
4760+
Tmp2 = ComputeNumSignBits(Op.getOperand(1), DemandedElts, Depth + 1);
4761+
return std::min(Tmp, Tmp2);
47554762
case ISD::SREM:
47564763
// The sign bit is the LHS's sign bit, except when the result of the
47574764
// remainder is zero. The magnitude of the result should be less than or

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -947,11 +947,11 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
947947

948948
// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
949949
// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
950-
static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
950+
static SDValue combineShiftToAVG(SDValue Op,
951+
TargetLowering::TargetLoweringOpt &TLO,
951952
const TargetLowering &TLI,
952953
const APInt &DemandedBits,
953-
const APInt &DemandedElts,
954-
unsigned Depth) {
954+
const APInt &DemandedElts, unsigned Depth) {
955955
assert((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) &&
956956
"SRL or SRA node is required here!");
957957
// Is the right shift using an immediate value of 1?
@@ -1002,6 +1002,7 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
10021002
// If the shift is unsigned (srl):
10031003
// - Needs >= 1 zero bit for both operands.
10041004
// - Needs 1 demanded bit zero and >= 2 sign bits.
1005+
SelectionDAG &DAG = TLO.DAG;
10051006
unsigned ShiftOpc = Op.getOpcode();
10061007
bool IsSigned = false;
10071008
unsigned KnownBits;
@@ -1057,10 +1058,10 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
10571058
EVT NVT = EVT::getIntegerVT(*DAG.getContext(), llvm::bit_ceil(MinWidth));
10581059
if (VT.isVector())
10591060
NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
1060-
if (!TLI.isOperationLegalOrCustom(AVGOpc, NVT)) {
1061+
if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, NVT)) {
10611062
// If we could not transform, and (both) adds are nuw/nsw, we can use the
10621063
// larger type size to do the transform.
1063-
if (!TLI.isOperationLegalOrCustom(AVGOpc, VT))
1064+
if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, VT))
10641065
return SDValue();
10651066
if (DAG.willNotOverflowAdd(IsSigned, Add.getOperand(0),
10661067
Add.getOperand(1)) &&
@@ -1907,11 +1908,6 @@ bool TargetLowering::SimplifyDemandedBits(
19071908
SDValue Op1 = Op.getOperand(1);
19081909
EVT ShiftVT = Op1.getValueType();
19091910

1910-
// Try to match AVG patterns.
1911-
if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
1912-
DemandedElts, Depth + 1))
1913-
return TLO.CombineTo(Op, AVG);
1914-
19151911
if (const APInt *SA =
19161912
TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
19171913
unsigned ShAmt = SA->getZExtValue();
@@ -1992,6 +1988,12 @@ bool TargetLowering::SimplifyDemandedBits(
19921988
// shift amounts.
19931989
Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth);
19941990
}
1991+
1992+
// Try to match AVG patterns (after shift simplification).
1993+
if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
1994+
DemandedElts, Depth + 1))
1995+
return TLO.CombineTo(Op, AVG);
1996+
19951997
break;
19961998
}
19971999
case ISD::SRA: {
@@ -2013,11 +2015,6 @@ bool TargetLowering::SimplifyDemandedBits(
20132015
if (DemandedBits.isOne())
20142016
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
20152017

2016-
// Try to match AVG patterns.
2017-
if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
2018-
DemandedElts, Depth + 1))
2019-
return TLO.CombineTo(Op, AVG);
2020-
20212018
if (const APInt *SA =
20222019
TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
20232020
unsigned ShAmt = SA->getZExtValue();
@@ -2103,6 +2100,12 @@ bool TargetLowering::SimplifyDemandedBits(
21032100
}
21042101
}
21052102
}
2103+
2104+
// Try to match AVG patterns (after shift simplification).
2105+
if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
2106+
DemandedElts, Depth + 1))
2107+
return TLO.CombineTo(Op, AVG);
2108+
21062109
break;
21072110
}
21082111
case ISD::FSHL:
@@ -9200,6 +9203,49 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
92009203
DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
92019204
}
92029205

9206+
SDValue TargetLowering::expandAVG(SDNode *N, SelectionDAG &DAG) const {
9207+
SDLoc dl(N);
9208+
EVT VT = N->getValueType(0);
9209+
SDValue LHS = N->getOperand(0);
9210+
SDValue RHS = N->getOperand(1);
9211+
9212+
unsigned Opc = N->getOpcode();
9213+
bool IsFloor = Opc == ISD::AVGFLOORS || Opc == ISD::AVGFLOORU;
9214+
bool IsSigned = Opc == ISD::AVGCEILS || Opc == ISD::AVGFLOORS;
9215+
unsigned ShiftOpc = IsSigned ? ISD::SRA : ISD::SRL;
9216+
assert((Opc == ISD::AVGFLOORS || Opc == ISD::AVGCEILS ||
9217+
Opc == ISD::AVGFLOORU || Opc == ISD::AVGCEILU) &&
9218+
"Unknown AVG node");
9219+
9220+
// If the operands are already extended, we can add+shift.
9221+
bool IsExt =
9222+
(IsSigned && DAG.ComputeNumSignBits(LHS) >= 2 &&
9223+
DAG.ComputeNumSignBits(RHS) >= 2) ||
9224+
(!IsSigned && DAG.computeKnownBits(LHS).countMinLeadingZeros() >= 1 &&
9225+
DAG.computeKnownBits(RHS).countMinLeadingZeros() >= 1);
9226+
if (IsExt) {
9227+
SDValue Sum = DAG.getNode(ISD::ADD, dl, VT, LHS, RHS);
9228+
if (!IsFloor)
9229+
Sum = DAG.getNode(ISD::ADD, dl, VT, Sum, DAG.getConstant(1, dl, VT));
9230+
return DAG.getNode(ShiftOpc, dl, VT, Sum,
9231+
DAG.getShiftAmountConstant(1, VT, dl));
9232+
}
9233+
9234+
// avgceils(lhs, rhs) -> sub(or(lhs,rhs),ashr(xor(lhs,rhs),1))
9235+
// avgceilu(lhs, rhs) -> sub(or(lhs,rhs),lshr(xor(lhs,rhs),1))
9236+
// avgfloors(lhs, rhs) -> add(and(lhs,rhs),ashr(xor(lhs,rhs),1))
9237+
// avgflooru(lhs, rhs) -> add(and(lhs,rhs),lshr(xor(lhs,rhs),1))
9238+
unsigned SumOpc = IsFloor ? ISD::ADD : ISD::SUB;
9239+
unsigned SignOpc = IsFloor ? ISD::AND : ISD::OR;
9240+
LHS = DAG.getFreeze(LHS);
9241+
RHS = DAG.getFreeze(RHS);
9242+
SDValue Sign = DAG.getNode(SignOpc, dl, VT, LHS, RHS);
9243+
SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, LHS, RHS);
9244+
SDValue Shift =
9245+
DAG.getNode(ShiftOpc, dl, VT, Xor, DAG.getShiftAmountConstant(1, VT, dl));
9246+
return DAG.getNode(SumOpc, dl, VT, Sign, Shift);
9247+
}
9248+
92039249
SDValue TargetLowering::expandBSWAP(SDNode *N, SelectionDAG &DAG) const {
92049250
SDLoc dl(N);
92059251
EVT VT = N->getValueType(0);

0 commit comments

Comments
 (0)