@@ -486,6 +486,7 @@ namespace {
486
486
SDValue visitSIGN_EXTEND_INREG(SDNode *N);
487
487
SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
488
488
SDValue visitTRUNCATE(SDNode *N);
489
+ SDValue visitTRUNCATE_USAT(SDNode *N);
489
490
SDValue visitBITCAST(SDNode *N);
490
491
SDValue visitFREEZE(SDNode *N);
491
492
SDValue visitBUILD_PAIR(SDNode *N);
@@ -1908,6 +1909,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
1908
1909
case ISD::ZERO_EXTEND_VECTOR_INREG:
1909
1910
case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
1910
1911
case ISD::TRUNCATE: return visitTRUNCATE(N);
1912
+ case ISD::TRUNCATE_USAT_U:
1913
+ case ISD::TRUNCATE_SSAT_U: return visitTRUNCATE_USAT(N);
1911
1914
case ISD::BITCAST: return visitBITCAST(N);
1912
1915
case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
1913
1916
case ISD::FADD: return visitFADD(N);
@@ -13203,7 +13206,9 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
13203
13206
unsigned CastOpcode = Cast->getOpcode();
13204
13207
assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
13205
13208
CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
13206
- CastOpcode == ISD::FP_ROUND) &&
13209
+ CastOpcode == ISD::TRUNCATE_SSAT_S ||
13210
+ CastOpcode == ISD::TRUNCATE_SSAT_U ||
13211
+ CastOpcode == ISD::TRUNCATE_USAT_U || CastOpcode == ISD::FP_ROUND) &&
13207
13212
"Unexpected opcode for vector select narrowing/widening");
13208
13213
13209
13214
// We only do this transform before legal ops because the pattern may be
@@ -14915,6 +14920,159 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
14915
14920
return SDValue();
14916
14921
}
14917
14922
14923
+ SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) {
14924
+ EVT VT = N->getValueType(0);
14925
+ SDValue N0 = N->getOperand(0);
14926
+ SDValue FPInstr = N0.getOpcode() == ISD::SMAX ? N0.getOperand(0) : N0;
14927
+ if (FPInstr.getOpcode() == ISD::FP_TO_SINT ||
14928
+ FPInstr.getOpcode() == ISD::FP_TO_UINT) {
14929
+ EVT FPVT = FPInstr.getOperand(0).getValueType();
14930
+ if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
14931
+ FPVT, VT))
14932
+ return SDValue();
14933
+ SDValue Sat = DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT,
14934
+ FPInstr.getOperand(0),
14935
+ DAG.getValueType(VT.getScalarType()));
14936
+ return Sat;
14937
+ }
14938
+
14939
+ return SDValue();
14940
+ }
14941
+
14942
+ /// Detect patterns of truncation with unsigned saturation:
14943
+ ///
14944
+ /// (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
14945
+ /// Return the source value x to be truncated or SDValue() if the pattern was
14946
+ /// not matched.
14947
+ ///
14948
+ static SDValue detectUSatUPattern(SDValue In, EVT VT) {
14949
+ EVT InVT = In.getValueType();
14950
+
14951
+ // Saturation with truncation. We truncate from InVT to VT.
14952
+ assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
14953
+ "Unexpected types for truncate operation");
14954
+
14955
+ // Match min/max and return limit value as a parameter.
14956
+ auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
14957
+ if (V.getOpcode() == Opcode &&
14958
+ ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
14959
+ return V.getOperand(0);
14960
+ return SDValue();
14961
+ };
14962
+
14963
+ APInt C1, C2;
14964
+ if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
14965
+ // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
14966
+ // the element size of the destination type.
14967
+ if (C2.isMask(VT.getScalarSizeInBits()))
14968
+ return UMin;
14969
+
14970
+ return SDValue();
14971
+ }
14972
+
14973
+ /// Detect patterns of truncation with signed saturation:
14974
+ /// (truncate (smin ((smax (x, signed_min_of_dest_type)),
14975
+ /// signed_max_of_dest_type)) to dest_type)
14976
+ /// or:
14977
+ /// (truncate (smax ((smin (x, signed_max_of_dest_type)),
14978
+ /// signed_min_of_dest_type)) to dest_type).
14979
+ /// With MatchPackUS, the smax/smin range is [0, unsigned_max_of_dest_type].
14980
+ /// Return the source value to be truncated or SDValue() if the pattern was not
14981
+ /// matched.
14982
+ static SDValue detectSSatSPattern(SDValue In, EVT VT) {
14983
+ unsigned NumDstBits = VT.getScalarSizeInBits();
14984
+ unsigned NumSrcBits = In.getScalarValueSizeInBits();
14985
+ assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
14986
+
14987
+ auto MatchMinMax = [](SDValue V, unsigned Opcode,
14988
+ const APInt &Limit) -> SDValue {
14989
+ APInt C;
14990
+ if (V.getOpcode() == Opcode &&
14991
+ ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
14992
+ return V.getOperand(0);
14993
+ return SDValue();
14994
+ };
14995
+
14996
+ APInt SignedMax, SignedMin;
14997
+ SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
14998
+ SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
14999
+ if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax)) {
15000
+ if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin)) {
15001
+ return SMax;
15002
+ }
15003
+ }
15004
+ if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin)) {
15005
+ if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax)) {
15006
+ return SMin;
15007
+ }
15008
+ }
15009
+ return SDValue();
15010
+ }
15011
+
15012
+ /// Detect patterns of truncation with unsigned saturation:
15013
+ ///
15014
+ /// (truncate (smin (smax (x, C1), C2)) to dest_type),
15015
+ /// where C1 >= 0 and C2 is unsigned max of destination type.
15016
+ ///
15017
+ /// (truncate (smax (smin (x, C2), C1)) to dest_type)
15018
+ /// where C1 >= 0, C2 is unsigned max of destination type and C1 <= C2.
15019
+ ///
15020
+ static SDValue detectSSatUPattern(SDValue In, EVT VT, SelectionDAG &DAG,
15021
+ const SDLoc &DL) {
15022
+ EVT InVT = In.getValueType();
15023
+
15024
+ // Saturation with truncation. We truncate from InVT to VT.
15025
+ assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
15026
+ "Unexpected types for truncate operation");
15027
+
15028
+ // Match min/max and return limit value as a parameter.
15029
+ auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
15030
+ if (V.getOpcode() == Opcode &&
15031
+ ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
15032
+ return V.getOperand(0);
15033
+ return SDValue();
15034
+ };
15035
+
15036
+ APInt C1, C2;
15037
+ if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
15038
+ if (MatchMinMax(SMin, ISD::SMAX, C1))
15039
+ if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
15040
+ return SMin;
15041
+
15042
+ if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
15043
+ if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
15044
+ if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
15045
+ C2.uge(C1))
15046
+ return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));
15047
+
15048
+ return SDValue();
15049
+ }
15050
+
15051
+ static SDValue foldToSaturated(SDNode *N, EVT &VT, SDValue &Src, EVT &SrcVT,
15052
+ SDLoc &DL, const TargetLowering &TLI,
15053
+ SelectionDAG &DAG) {
15054
+ if (Src.getOpcode() == ISD::SMIN || Src.getOpcode() == ISD::SMAX) {
15055
+ if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_SSAT_S, SrcVT) &&
15056
+ TLI.isTypeDesirableForOp(ISD::TRUNCATE_SSAT_S, VT)) {
15057
+ if (SDValue SSatVal = detectSSatSPattern(Src, VT))
15058
+ return DAG.getNode(ISD::TRUNCATE_SSAT_S, DL, VT, SSatVal);
15059
+ } else if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_SSAT_U, SrcVT) &&
15060
+ TLI.isTypeDesirableForOp(ISD::TRUNCATE_SSAT_U, VT)) {
15061
+ if (SDValue SSatVal = detectSSatUPattern(Src, VT, DAG, DL))
15062
+ return DAG.getNode(ISD::TRUNCATE_SSAT_S, DL, VT, SSatVal);
15063
+ }
15064
+ } else if (Src.getOpcode() == ISD::UMIN) {
15065
+ if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_USAT_U, SrcVT) &&
15066
+ TLI.isTypeDesirableForOp(ISD::TRUNCATE_USAT_U, VT)) {
15067
+ if (SDValue USatVal = detectUSatUPattern(Src, VT)) {
15068
+ return DAG.getNode(ISD::TRUNCATE_USAT_U, DL, VT, USatVal);
15069
+ }
15070
+ }
15071
+ }
15072
+
15073
+ return SDValue();
15074
+ }
15075
+
14918
15076
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
14919
15077
SDValue N0 = N->getOperand(0);
14920
15078
EVT VT = N->getValueType(0);
@@ -14930,6 +15088,11 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
14930
15088
if (N0.getOpcode() == ISD::TRUNCATE)
14931
15089
return DAG.getNode(ISD::TRUNCATE, DL, VT, N0.getOperand(0));
14932
15090
15091
+ // fold saturated truncate
15092
+ if (SDValue SaturatedTR = foldToSaturated(N, VT, N0, SrcVT, DL, TLI, DAG)) {
15093
+ return SaturatedTR;
15094
+ }
15095
+
14933
15096
// fold (truncate c1) -> c1
14934
15097
if (SDValue C = DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, VT, {N0}))
14935
15098
return C;
0 commit comments