@@ -486,6 +486,7 @@ namespace {
486
486
SDValue visitSIGN_EXTEND_INREG(SDNode *N);
487
487
SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
488
488
SDValue visitTRUNCATE(SDNode *N);
489
+ SDValue visitTRUNCATE_USAT(SDNode *N);
489
490
SDValue visitBITCAST(SDNode *N);
490
491
SDValue visitFREEZE(SDNode *N);
491
492
SDValue visitBUILD_PAIR(SDNode *N);
@@ -13203,7 +13204,9 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
13203
13204
unsigned CastOpcode = Cast->getOpcode();
13204
13205
assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
13205
13206
CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
13206
- CastOpcode == ISD::FP_ROUND) &&
13207
+ CastOpcode == ISD::TRUNCATE_SSAT_S ||
13208
+ CastOpcode == ISD::TRUNCATE_SSAT_U ||
13209
+ CastOpcode == ISD::TRUNCATE_USAT_U || CastOpcode == ISD::FP_ROUND) &&
13207
13210
"Unexpected opcode for vector select narrowing/widening");
13208
13211
13209
13212
// We only do this transform before legal ops because the pattern may be
@@ -14915,6 +14918,109 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
14915
14918
return SDValue();
14916
14919
}
14917
14920
14921
+ /// Detect patterns of truncation with unsigned saturation:
14922
+ ///
14923
+ /// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
14924
+ /// Return the source value x to be truncated or SDValue() if the pattern was
14925
+ /// not matched.
14926
+ ///
14927
+ static SDValue detectUSatUPattern(SDValue In, EVT VT) {
14928
+ unsigned NumDstBits = VT.getScalarSizeInBits();
14929
+ unsigned NumSrcBits = In.getScalarValueSizeInBits();
14930
+ // Saturation with truncation. We truncate from InVT to VT.
14931
+ assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
14932
+
14933
+ SDValue Min;
14934
+ APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
14935
+ if (sd_match(In, m_UMin(m_Value(Min), m_SpecificInt(UnsignedMax))))
14936
+ return Min;
14937
+
14938
+ return SDValue();
14939
+ }
14940
+
14941
+ /// Detect patterns of truncation with signed saturation:
14942
+ /// (truncate (smin (smax (x, signed_min_of_dest_type),
14943
+ /// signed_max_of_dest_type)) to dest_type)
14944
+ /// or:
14945
+ /// (truncate (smax (smin (x, signed_max_of_dest_type),
14946
+ /// signed_min_of_dest_type)) to dest_type).
14947
+ ///
14948
+ /// Return the source value to be truncated or SDValue() if the pattern was not
14949
+ /// matched.
14950
+ static SDValue detectSSatSPattern(SDValue In, EVT VT) {
14951
+ unsigned NumDstBits = VT.getScalarSizeInBits();
14952
+ unsigned NumSrcBits = In.getScalarValueSizeInBits();
14953
+ // Saturation with truncation. We truncate from InVT to VT.
14954
+ assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
14955
+
14956
+ SDValue Val;
14957
+ APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
14958
+ APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
14959
+
14960
+ if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_SpecificInt(SignedMin)),
14961
+ m_SpecificInt(SignedMax))))
14962
+ return Val;
14963
+
14964
+ if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(SignedMax)),
14965
+ m_SpecificInt(SignedMin))))
14966
+ return Val;
14967
+
14968
+ return SDValue();
14969
+ }
14970
+
14971
+ /// Detect patterns of truncation with unsigned saturation:
14972
+ static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
14973
+ const SDLoc &DL) {
14974
+ unsigned NumDstBits = VT.getScalarSizeInBits();
14975
+ unsigned NumSrcBits = In.getScalarValueSizeInBits();
14976
+ // Saturation with truncation. We truncate from InVT to VT.
14977
+ assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
14978
+
14979
+ SDValue Val;
14980
+ APInt UnsignedMax = APInt::getMaxValue(NumDstBits).zext(NumSrcBits);
14981
+ // Min == 0, Max is unsigned max of destination type.
14982
+ if (sd_match(In, m_SMax(m_SMin(m_Value(Val), m_SpecificInt(UnsignedMax)),
14983
+ m_Zero())))
14984
+ return Val;
14985
+
14986
+ if (sd_match(In, m_SMin(m_SMax(m_Value(Val), m_Zero()),
14987
+ m_SpecificInt(UnsignedMax))))
14988
+ return Val;
14989
+
14990
+ if (sd_match(In, m_UMin(m_SMax(m_Value(Val), m_Zero()),
14991
+ m_SpecificInt(UnsignedMax))))
14992
+ return Val;
14993
+
14994
+ return SDValue();
14995
+ }
14996
+
14997
+ static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT,
14998
+ SDLoc &DL, const TargetLowering &TLI,
14999
+ SelectionDAG &DAG) {
15000
+ auto AllowedTruncateSat = [&](unsigned Opc, EVT SrcVT, EVT VT) -> bool {
15001
+ return (TLI.isOperationLegalOrCustom(Opc, SrcVT) &&
15002
+ TLI.isTypeDesirableForOp(Opc, VT));
15003
+ };
15004
+
15005
+ if (Src.getOpcode() == ISD::SMIN || Src.getOpcode() == ISD::SMAX) {
15006
+ if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_S, SrcVT, VT))
15007
+ if (SDValue SSatVal = detectSSatSPattern(Src, VT))
15008
+ return DAG.getNode(ISD::TRUNCATE_SSAT_S, DL, VT, SSatVal);
15009
+ if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
15010
+ if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
15011
+ return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
15012
+ } else if (Src.getOpcode() == ISD::UMIN) {
15013
+ if (AllowedTruncateSat(ISD::TRUNCATE_SSAT_U, SrcVT, VT))
15014
+ if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
15015
+ return DAG.getNode(ISD::TRUNCATE_SSAT_U, DL, VT, SSatVal);
15016
+ if (AllowedTruncateSat(ISD::TRUNCATE_USAT_U, SrcVT, VT))
15017
+ if (SDValue USatVal = detectUSatUPattern(Src, VT))
15018
+ return DAG.getNode(ISD::TRUNCATE_USAT_U, DL, VT, USatVal);
15019
+ }
15020
+
15021
+ return SDValue();
15022
+ }
15023
+
14918
15024
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
14919
15025
SDValue N0 = N->getOperand(0);
14920
15026
EVT VT = N->getValueType(0);
@@ -14930,6 +15036,10 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
14930
15036
if (N0.getOpcode() == ISD::TRUNCATE)
14931
15037
return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
14932
15038
15039
+ // fold saturated truncate
15040
+ if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG))
15041
+ return SaturatedTR;
15042
+
14933
15043
// fold (truncate c1) -> c1
14934
15044
if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, VT, {N0}))
14935
15045
return C;
0 commit comments