Skip to content

[DAG] Add legalization handling for AVGCEIL/AVGFLOOR nodes #92096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -5345,6 +5345,11 @@ class TargetLowering : public TargetLoweringBase {
/// \returns The expansion result or SDValue() if it fails.
SDValue expandABD(SDNode *N, SelectionDAG &DAG) const;

/// Expand vector/scalar AVGCEILS/AVGCEILU/AVGFLOORS/AVGFLOORU nodes.
/// \param N Node to expand
/// \returns The expansion result or SDValue() if it fails.
SDValue expandAVG(SDNode *N, SelectionDAG &DAG) const;

/// Expand BSWAP nodes. Expands scalar/vector BSWAP nodes with i16/i32/i64
/// scalar types. Returns SDValue() if expand fails.
/// \param N Node to expand
Expand Down
24 changes: 20 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2575,13 +2575,13 @@ SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
EVT VT = N0.getValueType();
SDValue A, B;

if (hasOperation(ISD::AVGCEILU, VT) &&
if ((!LegalOperations || hasOperation(ISD::AVGCEILU, VT)) &&
sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
m_SpecificInt(1))))) {
return DAG.getNode(ISD::AVGCEILU, DL, VT, A, B);
}
if (hasOperation(ISD::AVGCEILS, VT) &&
if ((!LegalOperations || hasOperation(ISD::AVGCEILS, VT)) &&
sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
m_SpecificInt(1))))) {
Expand Down Expand Up @@ -2947,13 +2947,13 @@ SDValue DAGCombiner::foldAddToAvg(SDNode *N, const SDLoc &DL) {
EVT VT = N0.getValueType();
SDValue A, B;

if (hasOperation(ISD::AVGFLOORU, VT) &&
if ((!LegalOperations || hasOperation(ISD::AVGFLOORU, VT)) &&
sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
m_Srl(m_Xor(m_Deferred(A), m_Deferred(B)),
m_SpecificInt(1))))) {
return DAG.getNode(ISD::AVGFLOORU, DL, VT, A, B);
}
if (hasOperation(ISD::AVGFLOORS, VT) &&
if ((!LegalOperations || hasOperation(ISD::AVGFLOORS, VT)) &&
sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
m_SpecificInt(1))))) {
Expand Down Expand Up @@ -5253,6 +5253,22 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
SDValue AvgCeilU = DAG.getNode(ISD::AVGCEILU, DL, A.getValueType(), A, B);
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgCeilU);
}

// Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
// Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
// Check if avgflooru isn't legal/custom but avgceilu is.
if (Opcode == ISD::AVGFLOORU && !hasOperation(ISD::AVGFLOORU, VT) &&
(!LegalOperations || hasOperation(ISD::AVGCEILU, VT))) {
if (DAG.isKnownNeverZero(N1))
return DAG.getNode(
ISD::AVGCEILU, DL, VT, N0,
DAG.getNode(ISD::ADD, DL, VT, N1, DAG.getAllOnesConstant(DL, VT)));
if (DAG.isKnownNeverZero(N0))
return DAG.getNode(
ISD::AVGCEILU, DL, VT, N1,
DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
}

return SDValue();
}

Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3059,6 +3059,13 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
if ((Tmp1 = TLI.expandABD(Node, DAG)))
Results.push_back(Tmp1);
break;
case ISD::AVGCEILS:
case ISD::AVGCEILU:
case ISD::AVGFLOORS:
case ISD::AVGFLOORU:
if ((Tmp1 = TLI.expandAVG(Node, DAG)))
Results.push_back(Tmp1);
break;
case ISD::CTPOP:
if ((Tmp1 = TLI.expandCTPOP(Node, DAG)))
Results.push_back(Tmp1);
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,17 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::VP_SUB:
case ISD::VP_MUL: Res = PromoteIntRes_SimpleIntBinOp(N); break;

case ISD::AVGCEILS:
case ISD::AVGFLOORS:
case ISD::VP_SMIN:
case ISD::VP_SMAX:
case ISD::SDIV:
case ISD::SREM:
case ISD::VP_SDIV:
case ISD::VP_SREM: Res = PromoteIntRes_SExtIntBinOp(N); break;

case ISD::AVGCEILU:
case ISD::AVGFLOORU:
case ISD::VP_UMIN:
case ISD::VP_UMAX:
case ISD::UDIV:
Expand Down Expand Up @@ -2818,6 +2822,11 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::SSHLSAT:
case ISD::USHLSAT: ExpandIntRes_SHLSAT(N, Lo, Hi); break;

case ISD::AVGCEILS:
case ISD::AVGCEILU:
case ISD::AVGFLOORS:
case ISD::AVGFLOORU: ExpandIntRes_AVG(N, Lo, Hi); break;

case ISD::SMULFIX:
case ISD::SMULFIXSAT:
case ISD::UMULFIX:
Expand Down Expand Up @@ -4120,6 +4129,11 @@ void DAGTypeLegalizer::ExpandIntRes_READCOUNTER(SDNode *N, SDValue &Lo,
ReplaceValueWith(SDValue(N, 1), R.getValue(2));
}

void DAGTypeLegalizer::ExpandIntRes_AVG(SDNode *N, SDValue &Lo, SDValue &Hi) {
SDValue Result = TLI.expandAVG(N, DAG);
SplitInteger(Result, Lo, Hi);
}

void DAGTypeLegalizer::ExpandIntRes_ADDSUBSAT(SDNode *N, SDValue &Lo,
SDValue &Hi) {
SDValue Result = TLI.expandAddSubSat(N, DAG);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void ExpandIntRes_SADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_UADDSUBO (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_XMULO (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_AVG (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_ADDSUBSAT (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_SHLSAT (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_MULFIX (SDNode *N, SDValue &Lo, SDValue &Hi);
Expand Down
13 changes: 13 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,10 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::ABS:
case ISD::ABDS:
case ISD::ABDU:
case ISD::AVGCEILS:
case ISD::AVGCEILU:
case ISD::AVGFLOORS:
case ISD::AVGFLOORU:
case ISD::BSWAP:
case ISD::BITREVERSE:
case ISD::CTLZ:
Expand Down Expand Up @@ -918,6 +922,15 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
return;
}
break;
case ISD::AVGCEILS:
case ISD::AVGCEILU:
case ISD::AVGFLOORS:
case ISD::AVGFLOORU:
if (SDValue Expanded = TLI.expandAVG(Node, DAG)) {
Results.push_back(Expanded);
return;
}
break;
case ISD::BITREVERSE:
ExpandBITREVERSE(Node, Results);
return;
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
break;
case ISD::ADD:
case ISD::AND:
case ISD::AVGCEILS:
case ISD::AVGCEILU:
case ISD::AVGFLOORS:
case ISD::AVGFLOORU:
case ISD::FADD:
case ISD::FCOPYSIGN:
case ISD::FDIV:
Expand Down Expand Up @@ -1173,6 +1177,10 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::MUL: case ISD::VP_MUL:
case ISD::MULHS:
case ISD::MULHU:
case ISD::AVGCEILS:
case ISD::AVGCEILU:
case ISD::AVGFLOORS:
case ISD::AVGFLOORU:
case ISD::FADD: case ISD::VP_FADD:
case ISD::FSUB: case ISD::VP_FSUB:
case ISD::FMUL: case ISD::VP_FMUL:
Expand Down
58 changes: 51 additions & 7 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -951,11 +951,11 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(

// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
static SDValue combineShiftToAVG(SDValue Op,
TargetLowering::TargetLoweringOpt &TLO,
const TargetLowering &TLI,
const APInt &DemandedBits,
const APInt &DemandedElts,
unsigned Depth) {
const APInt &DemandedElts, unsigned Depth) {
assert((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) &&
"SRL or SRA node is required here!");
// Is the right shift using an immediate value of 1?
Expand Down Expand Up @@ -1006,6 +1006,7 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
// If the shift is unsigned (srl):
// - Needs >= 1 zero bit for both operands.
// - Needs 1 demanded bit zero and >= 2 sign bits.
SelectionDAG &DAG = TLO.DAG;
unsigned ShiftOpc = Op.getOpcode();
bool IsSigned = false;
unsigned KnownBits;
Expand Down Expand Up @@ -1061,10 +1062,10 @@ static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
EVT NVT = EVT::getIntegerVT(*DAG.getContext(), llvm::bit_ceil(MinWidth));
if (VT.isVector())
NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
if (!TLI.isOperationLegalOrCustom(AVGOpc, NVT)) {
if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, NVT)) {
// If we could not transform, and (both) adds are nuw/nsw, we can use the
// larger type size to do the transform.
if (!TLI.isOperationLegalOrCustom(AVGOpc, VT))
if (TLO.LegalOperations() && !TLI.isOperationLegal(AVGOpc, VT))
return SDValue();
if (DAG.willNotOverflowAdd(IsSigned, Add.getOperand(0),
Add.getOperand(1)) &&
Expand Down Expand Up @@ -2002,7 +2003,7 @@ bool TargetLowering::SimplifyDemandedBits(
}

// Try to match AVG patterns (after shift simplification).
if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
DemandedElts, Depth + 1))
return TLO.CombineTo(Op, AVG);

Expand Down Expand Up @@ -2113,7 +2114,7 @@ bool TargetLowering::SimplifyDemandedBits(
}

// Try to match AVG patterns (after shift simplification).
if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
DemandedElts, Depth + 1))
return TLO.CombineTo(Op, AVG);

Expand Down Expand Up @@ -9225,6 +9226,49 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));
}

SDValue TargetLowering::expandAVG(SDNode *N, SelectionDAG &DAG) const {
SDLoc dl(N);
EVT VT = N->getValueType(0);
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);

unsigned Opc = N->getOpcode();
bool IsFloor = Opc == ISD::AVGFLOORS || Opc == ISD::AVGFLOORU;
bool IsSigned = Opc == ISD::AVGCEILS || Opc == ISD::AVGFLOORS;
unsigned ShiftOpc = IsSigned ? ISD::SRA : ISD::SRL;
assert((Opc == ISD::AVGFLOORS || Opc == ISD::AVGCEILS ||
Opc == ISD::AVGFLOORU || Opc == ISD::AVGCEILU) &&
"Unknown AVG node");

// If the operands are already extended, we can add+shift.
bool IsExt =
(IsSigned && DAG.ComputeNumSignBits(LHS) >= 2 &&
DAG.ComputeNumSignBits(RHS) >= 2) ||
(!IsSigned && DAG.computeKnownBits(LHS).countMinLeadingZeros() >= 1 &&
DAG.computeKnownBits(RHS).countMinLeadingZeros() >= 1);
if (IsExt) {
SDValue Sum = DAG.getNode(ISD::ADD, dl, VT, LHS, RHS);
if (!IsFloor)
Sum = DAG.getNode(ISD::ADD, dl, VT, Sum, DAG.getConstant(1, dl, VT));
return DAG.getNode(ShiftOpc, dl, VT, Sum,
DAG.getShiftAmountConstant(1, VT, dl));
}

// avgceils(lhs, rhs) -> sub(or(lhs,rhs),ashr(xor(lhs,rhs),1))
// avgceilu(lhs, rhs) -> sub(or(lhs,rhs),lshr(xor(lhs,rhs),1))
// avgfloors(lhs, rhs) -> add(and(lhs,rhs),ashr(xor(lhs,rhs),1))
// avgflooru(lhs, rhs) -> add(and(lhs,rhs),lshr(xor(lhs,rhs),1))
unsigned SumOpc = IsFloor ? ISD::ADD : ISD::SUB;
unsigned SignOpc = IsFloor ? ISD::AND : ISD::OR;
LHS = DAG.getFreeze(LHS);
RHS = DAG.getFreeze(RHS);
SDValue Sign = DAG.getNode(SignOpc, dl, VT, LHS, RHS);
SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, LHS, RHS);
SDValue Shift =
DAG.getNode(ShiftOpc, dl, VT, Xor, DAG.getShiftAmountConstant(1, VT, dl));
return DAG.getNode(SumOpc, dl, VT, Sign, Shift);
}

SDValue TargetLowering::expandBSWAP(SDNode *N, SelectionDAG &DAG) const {
SDLoc dl(N);
EVT VT = N->getValueType(0);
Expand Down
Loading
Loading