Skip to content

Commit 8d81896

Browse files
committed
[DAG] Support saturated truncate
A truncate is considered saturated if no additional conversion is required between the target and return values. If the target is saturated when attempting to truncate from a vector, there is an opportunity to optimize it. Previously, each architecture had its own attempt at optimization, leading to redundant code. This patch implements common logic by introducing three new ISDs: `ISD::TRUNCATE_SSAT_S`: When the operand is a signed value and the range of values matches the range of signed values of the destination type. `ISD::TRUNCATE_SSAT_U`: When the operand is a signed value and the range of values matches the range of unsigned values of the destination type. `ISD::TRUNCATE_USAT_U`: When the operand is an unsigned value and the range of values matches the range of unsigned values of the destination type. These ISDs indicate a saturated truncate. Fixes #85903
1 parent f8006a5 commit 8d81896

File tree

5 files changed

+142
-1
lines changed

5 files changed

+142
-1
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,26 @@ enum NodeType {
814814

815815
/// TRUNCATE - Completely drop the high bits.
816816
TRUNCATE,
817+
/// TRUNCATE_[SU]SAT_[SU] - Truncate for saturated operand
818+
/// [SU] located in middle, prefix for `SAT` means indicates whether
819+
/// existing truncate target was a signed operation. For examples,
820+
/// If `truncate(smin(smax(x, C), C))` was saturated then become `S`.
821+
/// If `truncate(umin(x, C))` was saturated then become `U`.
822+
/// [SU] located in last indicates whether range of truncated values is
823+
/// sign-saturated. For example, if `truncate(smin(smax(x, C), C))` is a
824+
/// truncation to `i8`, then if value of C ranges from `-128 to 127`, it will
825+
/// be saturated against signed values, resulting in `S`, which will combine
826+
/// to `TRUNCATE_SSAT_S`. If the value of C ranges from `0 to 255`, it will
827+
/// be saturated against unsigned values, resulting in `U`, which will
828+
/// combine to `TRUNATE_SSAT_U`. Similarly, in `truncate(umin(x, C))`, if
829+
/// value of C ranges from `0 to 255`, it becomes `U` because it is saturated
830+
/// for unsigned values. As a result, it combines to `TRUNCATE_USAT_U`.
831+
TRUNCATE_SSAT_S, // saturate signed input to signed result -
832+
// truncate(smin(smax(x, C), C))
833+
TRUNCATE_SSAT_U, // saturate signed input to unsigned result -
834+
// truncate(smin(smax(x, 0), C))
835+
TRUNCATE_USAT_U, // saturate unsigned input to unsigned result -
836+
// truncate(umin(x, C))
817837

818838
/// [SU]INT_TO_FP - These operators convert integers (whose interpreted sign
819839
/// depends on the first letter) to floating point.

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,9 @@ def sext : SDNode<"ISD::SIGN_EXTEND", SDTIntExtendOp>;
477477
def zext : SDNode<"ISD::ZERO_EXTEND", SDTIntExtendOp>;
478478
def anyext : SDNode<"ISD::ANY_EXTEND" , SDTIntExtendOp>;
479479
def trunc : SDNode<"ISD::TRUNCATE" , SDTIntTruncOp>;
480+
def truncssat_s : SDNode<"ISD::TRUNCATE_SSAT_S", SDTIntTruncOp>;
481+
def truncssat_u : SDNode<"ISD::TRUNCATE_SSAT_U", SDTIntTruncOp>;
482+
def truncusat_u : SDNode<"ISD::TRUNCATE_USAT_U", SDTIntTruncOp>;
480483
def bitconvert : SDNode<"ISD::BITCAST" , SDTUnaryOp>;
481484
def addrspacecast : SDNode<"ISD::ADDRSPACECAST", SDTUnaryOp>;
482485
def freeze : SDNode<"ISD::FREEZE" , SDTFreeze>;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@ namespace {
486486
SDValue visitSIGN_EXTEND_INREG(SDNode *N);
487487
SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
488488
SDValue visitTRUNCATE(SDNode *N);
489+
SDValue visitTRUNCATE_USAT(SDNode *N);
489490
SDValue visitBITCAST(SDNode *N);
490491
SDValue visitFREEZE(SDNode *N);
491492
SDValue visitBUILD_PAIR(SDNode *N);
@@ -13203,7 +13204,9 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
1320313204
unsigned CastOpcode = Cast->getOpcode();
1320413205
assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
1320513206
CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
13206-
CastOpcode == ISD::FP_ROUND) &&
13207+
CastOpcode == ISD::TRUNCATE_SSAT_S ||
13208+
CastOpcode == ISD::TRUNCATE_SSAT_U ||
13209+
CastOpcode == ISD::TRUNCATE_USAT_U || CastOpcode == ISD::FP_ROUND) &&
1320713210
"Unexpected opcode for vector select narrowing/widening");
1320813211

1320913212
// We only do this transform before legal ops because the pattern may be
@@ -14915,6 +14918,109 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
1491514918
return SDValue();
1491614919
}
1491714920

14921+
/// Detect patterns of truncation with unsigned saturation:
14922+
///
14923+
/// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
14924+
/// Return the source value x to be truncated or SDValue() if the pattern was
14925+
/// not matched.
14926+
///
14927+
static SDValue detectUSatUPattern(SDValue In, EVT VT) {
14928+
unsigned NumDstBits = VT.getScalarSizeInBits();
14929+
unsigned NumSrcBits = In.getScalarValueSizeInBits();
14930+
// Saturation with truncation. We truncate from InVT to VT.
14931+
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
14932+
14933+
SDValue Min;
14934+
APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
14935+
if (sd_match(In, m_UMin(m_Value(Min), m_SpecificInt(UnsignedMax))))
14936+
return Min;
14937+
14938+
return SDValue();
14939+
}
14940+
14941+
/// Detect patterns of truncation with signed saturation:
14942+
/// (truncate (smin (smax (x, signed_min_of_dest_type),
14943+
/// signed_max_of_dest_type)) to dest_type)
14944+
/// or:
14945+
/// (truncate (smax (smin (x, signed_max_of_dest_type),
14946+
/// signed_min_of_dest_type)) to dest_type).
14947+
///
14948+
/// Return the source value to be truncated or SDValue() if the pattern was not
14949+
/// matched.
14950+
static SDValue detectSSatSPattern(SDValue In, EVT VT) {
14951+
unsigned NumDstBits = VT.getScalarSizeInBits();
14952+
unsigned NumSrcBits = In.getScalarValueSizeInBits();
14953+
// Saturation with truncation. We truncate from InVT to VT.
14954+
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
14955+
14956+
SDValue Val;
14957+
APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
14958+
APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
14959+
14960+
if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_SpecificInt(SignedMin)),
14961+
m_SpecificInt(SignedMax))))
14962+
return Val;
14963+
14964+
if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(SignedMax)),
14965+
m_SpecificInt(SignedMin))))
14966+
return Val;
14967+
14968+
return SDValue();
14969+
}
14970+
14971+
/// Detect patterns of truncation with unsigned saturation:
14972+
static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
14973+
const SDLoc &DL) {
14974+
unsigned NumDstBits = VT.getScalarSizeInBits();
14975+
unsigned NumSrcBits = In.getScalarValueSizeInBits();
14976+
// Saturation with truncation. We truncate from InVT to VT.
14977+
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
14978+
14979+
SDValue Val;
14980+
APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
14981+
// Min == 0, Max is unsigned max of destination type.
14982+
if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(UnsignedMax)),
14983+
m_Zero())))
14984+
return Val;
14985+
14986+
if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_Zero()),
14987+
m_SpecificInt(UnsignedMax))))
14988+
return Val;
14989+
14990+
if (sd_match(In, m_UMin(m_SMax(m_Value(Val), m_Zero()),
14991+
m_SpecificInt(UnsignedMax))))
14992+
return Val;
14993+
14994+
return SDValue();
14995+
}
14996+
14997+
static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT,
14998+
SDLoc &DL, const TargetLowering &TLI,
14999+
SelectionDAG &DAG) {
15000+
auto AllowedTruncateSat = [&](unsigned Opc, EVT SrcVT, EVT VT) -> bool {
15001+
return (TLI.isOperationLegalOrCustom(Opc, SrcVT) &&
15002+
TLI.isTypeDesirableForOp(Opc, VT));
15003+
};
15004+
15005+
if (Src.getOpcode() == ISD::SMIN || Src.getOpcode() == ISD::SMAX) {
15006+
if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_S, SrcVT, VT))
15007+
if (SDValue SSatVal = detectSSatSPattern(Src, VT))
15008+
return DAG.getNode(ISD::TRUNCATE_SSAT_S, DL, VT, SSatVal);
15009+
if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
15010+
if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
15011+
return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
15012+
} else if (Src.getOpcode() == ISD::UMIN) {
15013+
if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
15014+
if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
15015+
return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
15016+
if (AllowedTruncateSat(ISD::TRUNCATE_USAT_U, SrcVT, VT))
15017+
if (SDValue USatVal = detectUSatUPattern(Src, VT))
15018+
return DAG.getNode(ISD::TRUNCATE_USAT_U, DL, VT, USatVal);
15019+
}
15020+
15021+
return SDValue();
15022+
}
15023+
1491815024
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1491915025
SDValue N0 = N->getOperand(0);
1492015026
EVT VT = N->getValueType(0);
@@ -14930,6 +15036,10 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1493015036
if (N0.getOpcode() == ISD::TRUNCATE)
1493115037
return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
1493215038

15039+
// fold saturated truncate
15040+
if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG))
15041+
return SaturatedTR;
15042+
1493315043
// fold (truncate c1) -> c1
1493415044
if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, VT, {N0}))
1493515045
return C;

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
380380
case ISD::SIGN_EXTEND_VECTOR_INREG: return "sign_extend_vector_inreg";
381381
case ISD::ZERO_EXTEND_VECTOR_INREG: return "zero_extend_vector_inreg";
382382
case ISD::TRUNCATE: return "truncate";
383+
case ISD::TRUNCATE_SSAT_S: return "truncate_ssat_s";
384+
case ISD::TRUNCATE_SSAT_U: return "truncate_ssat_u";
385+
case ISD::TRUNCATE_USAT_U: return "truncate_usat_u";
383386
case ISD::FP_ROUND: return "fp_round";
384387
case ISD::STRICT_FP_ROUND: return "strict_fp_round";
385388
case ISD::FP_EXTEND: return "fp_extend";

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,11 @@ void TargetLoweringBase::initActions() {
753753
// Absolute difference
754754
setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Expand);
755755

756+
// Saturated trunc
757+
setOperationAction(ISD::TRUNCATE_SSAT_S, VT, Expand);
758+
setOperationAction(ISD::TRUNCATE_SSAT_U, VT, Expand);
759+
setOperationAction(ISD::TRUNCATE_USAT_U, VT, Expand);
760+
756761
// These default to Expand so they will be expanded to CTLZ/CTTZ by default.
757762
setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
758763
Expand);

0 commit comments

Comments
 (0)