Skip to content

Commit e3baff2

Browse files
committed
[DAG] Support saturated truncate
`truncate` is `saturated` if no additional conversion is required between the target and return values. If the target is `saturated` when attempting to crop from a `vector`, there is an opportunity to optimize it. Previously, each architecture attempted this optimization on its own, so there was redundant code. This patch implements the common logic by adding `ISD::TRUNCATE_[SU]SAT_[SU]` nodes to indicate a saturated truncate. Fixes #85903
1 parent f8006a5 commit e3baff2

File tree

5 files changed

+182
-1
lines changed

5 files changed

+182
-1
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,13 @@ enum NodeType {
814814

815815
/// TRUNCATE - Completely drop the high bits.
816816
TRUNCATE,
817+
/// TRUNCATE_[SU]SAT_[SU] - Truncate with saturation to the result type.
818+
TRUNCATE_SSAT_S, // saturate signed input to signed result -
819+
// truncate(smin(smax(x)))
820+
TRUNCATE_SSAT_U, // saturate signed input to unsigned result -
821+
// truncate(smin(smax(x,0)))
822+
TRUNCATE_USAT_U, // saturate unsigned input to unsigned result -
823+
// truncate(umin(x))
817824

818825
/// [SU]INT_TO_FP - These operators convert integers (whose interpreted sign
819826
/// depends on the first letter) to floating point.

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,9 @@ def sext : SDNode<"ISD::SIGN_EXTEND", SDTIntExtendOp>;
477477
def zext : SDNode<"ISD::ZERO_EXTEND", SDTIntExtendOp>;
478478
def anyext : SDNode<"ISD::ANY_EXTEND" , SDTIntExtendOp>;
479479
def trunc : SDNode<"ISD::TRUNCATE" , SDTIntTruncOp>;
480+
def truncssat_s : SDNode<"ISD::TRUNCATE_SSAT_S", SDTIntTruncOp>;
481+
def truncssat_u : SDNode<"ISD::TRUNCATE_SSAT_U", SDTIntTruncOp>;
482+
def truncusat_u : SDNode<"ISD::TRUNCATE_USAT_U", SDTIntTruncOp>;
480483
def bitconvert : SDNode<"ISD::BITCAST" , SDTUnaryOp>;
481484
def addrspacecast : SDNode<"ISD::ADDRSPACECAST", SDTUnaryOp>;
482485
def freeze : SDNode<"ISD::FREEZE" , SDTFreeze>;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 164 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ namespace {
486486
SDValue visitSIGN_EXTEND_INREG(SDNode *N);
487487
SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
488488
SDValue visitTRUNCATE(SDNode *N);
489+
SDValue visitTRUNCATE_USAT(SDNode *N);
489490
SDValue visitBITCAST(SDNode *N);
490491
SDValue visitFREEZE(SDNode *N);
491492
SDValue visitBUILD_PAIR(SDNode *N);
@@ -1908,6 +1909,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
19081909
case ISD::ZERO_EXTEND_VECTOR_INREG:
19091910
case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
19101911
case ISD::TRUNCATE: return visitTRUNCATE(N);
1912+
case ISD::TRUNCATE_USAT_U:
1913+
case ISD::TRUNCATE_SSAT_U: return visitTRUNCATE_USAT(N);
19111914
case ISD::BITCAST: return visitBITCAST(N);
19121915
case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
19131916
case ISD::FADD: return visitFADD(N);
@@ -13203,7 +13206,9 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
1320313206
unsigned CastOpcode = Cast->getOpcode();
1320413207
assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
1320513208
CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
13206-
CastOpcode == ISD::FP_ROUND) &&
13209+
CastOpcode == ISD::TRUNCATE_SSAT_S ||
13210+
CastOpcode == ISD::TRUNCATE_SSAT_U ||
13211+
CastOpcode == ISD::TRUNCATE_USAT_U || CastOpcode == ISD::FP_ROUND) &&
1320713212
"Unexpected opcode for vector select narrowing/widening");
1320813213

1320913214
// We only do this transform before legal ops because the pattern may be
@@ -14915,6 +14920,159 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
1491514920
return SDValue();
1491614921
}
1491714922

14923+
/// Combine a TRUNCATE_USAT_U / TRUNCATE_SSAT_U node whose operand is an
/// FP-to-integer conversion into a single FP_TO_UINT_SAT node.
///
/// FP_TO_UINT_SAT clamps its input to [0, unsigned_max_of_dest_type], so the
/// fold is only performed when the matched nodes compute exactly that:
///   * truncate_ssat_u (fp_to_sint x)           - negative values clamp to 0.
///   * truncate_[su]sat_u (smax (fp_to_sint x), 0)
///                                              - the smax clamps them to 0.
///   * truncate_usat_u (fp_to_uint x)           - the value is never negative.
/// Other combinations are left alone. For example, truncate_usat_u of a
/// negative fp_to_sint result saturates to the unsigned maximum, whereas
/// FP_TO_UINT_SAT would produce 0.
///
/// \param N The TRUNCATE_USAT_U or TRUNCATE_SSAT_U node being visited.
/// \returns the replacement FP_TO_UINT_SAT node, or SDValue() if no fold
///          applies or the target does not want the conversion.
SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) {
  EVT VT = N->getValueType(0);
  SDValue N0 = N->getOperand(0);

  // Only look through an smax that clamps at zero; any other lower bound
  // would be lost by the fold.
  const bool ClampedAtZero =
      N0.getOpcode() == ISD::SMAX && isNullOrNullSplat(N0.getOperand(1));
  SDValue FPInstr = ClampedAtZero ? N0.getOperand(0) : N0;

  bool CanFold = false;
  switch (FPInstr.getOpcode()) {
  case ISD::FP_TO_SINT:
    // A signed conversion may be negative: something must map negatives to
    // zero, either the explicit smax or the signed-input saturating truncate.
    CanFold = ClampedAtZero || N->getOpcode() == ISD::TRUNCATE_SSAT_U;
    break;
  case ISD::FP_TO_UINT:
    // An unsigned conversion must not be reinterpreted as signed, which both
    // smax and TRUNCATE_SSAT_U would do for values with the top bit set.
    CanFold = !ClampedAtZero && N->getOpcode() == ISD::TRUNCATE_USAT_U;
    break;
  default:
    break;
  }
  if (!CanFold)
    return SDValue();

  // Let the target veto the saturating conversion for this type pair.
  EVT FPVT = FPInstr.getOperand(0).getValueType();
  if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
                                                        FPVT, VT))
    return SDValue();

  return DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT,
                     FPInstr.getOperand(0),
                     DAG.getValueType(VT.getScalarType()));
}
14941+
14942+
/// Detect patterns of truncation with unsigned saturation:
14943+
///
14944+
/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
14945+
/// Return the source value x to be truncated or SDValue() if the pattern was
14946+
/// not matched.
14947+
///
14948+
static SDValue detectUSatUPattern(SDValue In, EVT VT) {
14949+
EVT InVT = In.getValueType();
14950+
14951+
// Saturation with truncation. We truncate from InVT to VT.
14952+
assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
14953+
"Unexpected types for truncate operation");
14954+
14955+
// Match min/max and return limit value as a parameter.
14956+
auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
14957+
if (V.getOpcode() == Opcode &&
14958+
ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
14959+
return V.getOperand(0);
14960+
return SDValue();
14961+
};
14962+
14963+
APInt C1, C2;
14964+
if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
14965+
// C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
14966+
// the element size of the destination type.
14967+
if (C2.isMask(VT.getScalarSizeInBits()))
14968+
return UMin;
14969+
14970+
return SDValue();
14971+
}
14972+
14973+
/// Detect patterns of truncation with signed saturation:
14974+
/// (truncate (smin ((smax (x, signed_min_of_dest_type)),
14975+
/// signed_max_of_dest_type)) to dest_type)
14976+
/// or:
14977+
/// (truncate (smax ((smin (x, signed_max_of_dest_type)),
14978+
/// signed_min_of_dest_type)) to dest_type).
14979+
/// With MatchPackUS, the smax/smin range is [0, unsigned_max_of_dest_type].
14980+
/// Return the source value to be truncated or SDValue() if the pattern was not
14981+
/// matched.
14982+
static SDValue detectSSatSPattern(SDValue In, EVT VT) {
14983+
unsigned NumDstBits = VT.getScalarSizeInBits();
14984+
unsigned NumSrcBits = In.getScalarValueSizeInBits();
14985+
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
14986+
14987+
auto MatchMinMax = [](SDValue V, unsigned Opcode,
14988+
const APInt &Limit) -> SDValue {
14989+
APInt C;
14990+
if (V.getOpcode() == Opcode &&
14991+
ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
14992+
return V.getOperand(0);
14993+
return SDValue();
14994+
};
14995+
14996+
APInt SignedMax, SignedMin;
14997+
SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
14998+
SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
14999+
if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax)) {
15000+
if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin)) {
15001+
return SMax;
15002+
}
15003+
}
15004+
if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin)) {
15005+
if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax)) {
15006+
return SMin;
15007+
}
15008+
}
15009+
return SDValue();
15010+
}
15011+
15012+
/// Detect patterns of truncation with unsigned saturation:
15013+
///
15014+
/// (truncate (smin (smax (x, C1), C2)) to dest_type),
15015+
/// where C1 >= 0 and C2 is unsigned max of destination type.
15016+
///
15017+
/// (truncate (smax (smin (x, C2), C1)) to dest_type)
15018+
/// where C1 >= 0, C2 is unsigned max of destination type and C1 <= C2.
15019+
///
15020+
static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
15021+
const SDLoc &DL) {
15022+
EVT InVT = In.getValueType();
15023+
15024+
// Saturation with truncation. We truncate from InVT to VT.
15025+
assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
15026+
"Unexpected types for truncate operation");
15027+
15028+
// Match min/max and return limit value as a parameter.
15029+
auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
15030+
if (V.getOpcode() == Opcode &&
15031+
ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
15032+
return V.getOperand(0);
15033+
return SDValue();
15034+
};
15035+
15036+
APInt C1, C2;
15037+
if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
15038+
if (MatchMinMax(SMin, ISD::SMAX, C1))
15039+
if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
15040+
return SMin;
15041+
15042+
if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
15043+
if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
15044+
if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
15045+
C2.uge(C1))
15046+
return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));
15047+
15048+
return SDValue();
15049+
}
15050+
15051+
/// Try to replace `truncate(Src)` with a saturating truncate node.
///
/// Depending on the clamp feeding the truncate this forms:
///   * TRUNCATE_SSAT_S for a signed clamp to the signed range of \p VT,
///   * TRUNCATE_SSAT_U for a signed clamp to the unsigned range of \p VT,
///   * TRUNCATE_USAT_U for an unsigned clamp to the unsigned range of \p VT.
/// A node is only formed when the target reports the opcode as legal or
/// custom for \p SrcVT and the result type as desirable for that opcode.
///
/// \param N     The TRUNCATE node being combined (currently unused).
/// \param VT    Result type of the truncate.
/// \param Src   Operand of the truncate.
/// \param SrcVT Type of \p Src.
/// \param DL    Debug location for the new nodes.
/// \param TLI   Target lowering hooks used for the legality checks.
/// \param DAG   DAG in which new nodes are created.
/// \returns the saturating truncate node, or SDValue() if nothing matched.
static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT,
                               SDLoc &DL, const TargetLowering &TLI,
                               SelectionDAG &DAG) {
  auto IsSatTruncAllowed = [&](unsigned Opc) {
    return TLI.isOperationLegalOrCustom(Opc, SrcVT) &&
           TLI.isTypeDesirableForOp(Opc, VT);
  };

  switch (Src.getOpcode()) {
  case ISD::SMIN:
  case ISD::SMAX:
    if (IsSatTruncAllowed(ISD::TRUNCATE_SSAT_S))
      if (SDValue SSatVal = detectSSatSPattern(Src, VT))
        return DAG.getNode(ISD::TRUNCATE_SSAT_S, DL, VT, SSatVal);

    // Tried independently of the signed form: a target may support both, and
    // a clamp to [0, unsigned_max] never matches the signed pattern.
    if (IsSatTruncAllowed(ISD::TRUNCATE_SSAT_U))
      if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
        // Must be the unsigned-result opcode: the matched clamp range is
        // [0, unsigned_max], not the signed range of VT.
        return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
    break;

  case ISD::UMIN:
    if (IsSatTruncAllowed(ISD::TRUNCATE_USAT_U))
      if (SDValue USatVal = detectUSatUPattern(Src, VT))
        return DAG.getNode(ISD::TRUNCATE_USAT_U, DL, VT, USatVal);
    break;

  default:
    break;
  }

  return SDValue();
}
15075+
1491815076
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1491915077
SDValue N0 = N->getOperand(0);
1492015078
EVT VT = N->getValueType(0);
@@ -14930,6 +15088,11 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1493015088
if (N0.getOpcode() == ISD::TRUNCATE)
1493115089
return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
1493215090

15091+
// fold saturated truncate
15092+
if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG)) {
15093+
return SaturatedTR;
15094+
}
15095+
1493315096
// fold (truncate c1) -> c1
1493415097
if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, VT, {N0}))
1493515098
return C;

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
380380
case ISD::SIGN_EXTEND_VECTOR_INREG: return "sign_extend_vector_inreg";
381381
case ISD::ZERO_EXTEND_VECTOR_INREG: return "zero_extend_vector_inreg";
382382
case ISD::TRUNCATE: return "truncate";
383+
case ISD::TRUNCATE_SSAT_S: return "truncate_ssat_s";
384+
case ISD::TRUNCATE_SSAT_U: return "truncate_ssat_u";
385+
case ISD::TRUNCATE_USAT_U: return "truncate_usat_u";
383386
case ISD::FP_ROUND: return "fp_round";
384387
case ISD::STRICT_FP_ROUND: return "strict_fp_round";
385388
case ISD::FP_EXTEND: return "fp_extend";

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,11 @@ void TargetLoweringBase::initActions() {
753753
// Absolute difference
754754
setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Expand);
755755

756+
// Saturated trunc
757+
setOperationAction(ISD::TRUNCATE_SSAT_S, VT, Expand);
758+
setOperationAction(ISD::TRUNCATE_SSAT_U, VT, Expand);
759+
setOperationAction(ISD::TRUNCATE_USAT_U, VT, Expand);
760+
756761
// These default to Expand so they will be expanded to CTLZ/CTTZ by default.
757762
setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
758763
Expand);

0 commit comments

Comments
 (0)