@@ -2690,6 +2690,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
2690
2690
MAKE_CASE(AArch64ISD::RSHRNB_I)
2691
2691
MAKE_CASE(AArch64ISD::CTTZ_ELTS)
2692
2692
MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
2693
+ MAKE_CASE(AArch64ISD::URSHR_I_PRED)
2693
2694
}
2694
2695
#undef MAKE_CASE
2695
2696
return nullptr;
@@ -2974,6 +2975,7 @@ static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
2974
2975
static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
2975
2976
static SDValue convertFixedMaskToScalableVector(SDValue Mask,
2976
2977
SelectionDAG &DAG);
2978
+ static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT);
2977
2979
static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
2978
2980
EVT VT);
2979
2981
@@ -13862,6 +13864,51 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
13862
13864
return SDValue();
13863
13865
}
13864
13866
13867
+ // Check if we can we lower this SRL to a rounding shift instruction. ResVT is
13868
+ // possibly a truncated type, it tells how many bits of the value are to be
13869
+ // used.
13870
+ static bool canLowerSRLToRoundingShiftForVT(SDValue Shift, EVT ResVT,
13871
+ SelectionDAG &DAG,
13872
+ unsigned &ShiftValue,
13873
+ SDValue &RShOperand) {
13874
+ if (Shift->getOpcode() != ISD::SRL)
13875
+ return false;
13876
+
13877
+ EVT VT = Shift.getValueType();
13878
+ assert(VT.isScalableVT());
13879
+
13880
+ auto ShiftOp1 =
13881
+ dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Shift->getOperand(1)));
13882
+ if (!ShiftOp1)
13883
+ return false;
13884
+
13885
+ ShiftValue = ShiftOp1->getZExtValue();
13886
+ if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
13887
+ return false;
13888
+
13889
+ SDValue Add = Shift->getOperand(0);
13890
+ if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
13891
+ return false;
13892
+
13893
+ assert(ResVT.getScalarSizeInBits() <= VT.getScalarSizeInBits() &&
13894
+ "ResVT must be truncated or same type as the shift.");
13895
+ // Check if an overflow can lead to incorrect results.
13896
+ uint64_t ExtraBits = VT.getScalarSizeInBits() - ResVT.getScalarSizeInBits();
13897
+ if (ShiftValue > ExtraBits && !Add->getFlags().hasNoUnsignedWrap())
13898
+ return false;
13899
+
13900
+ auto AddOp1 =
13901
+ dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
13902
+ if (!AddOp1)
13903
+ return false;
13904
+ uint64_t AddValue = AddOp1->getZExtValue();
13905
+ if (AddValue != 1ULL << (ShiftValue - 1))
13906
+ return false;
13907
+
13908
+ RShOperand = Add->getOperand(0);
13909
+ return true;
13910
+ }
13911
+
13865
13912
SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
13866
13913
SelectionDAG &DAG) const {
13867
13914
EVT VT = Op.getValueType();
@@ -13887,6 +13934,15 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
13887
13934
Op.getOperand(0), Op.getOperand(1));
13888
13935
case ISD::SRA:
13889
13936
case ISD::SRL:
13937
+ if (VT.isScalableVector() && Subtarget->hasSVE2orSME()) {
13938
+ SDValue RShOperand;
13939
+ unsigned ShiftValue;
13940
+ if (canLowerSRLToRoundingShiftForVT(Op, VT, DAG, ShiftValue, RShOperand))
13941
+ return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, VT,
13942
+ getPredicateForVector(DAG, DL, VT), RShOperand,
13943
+ DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
13944
+ }
13945
+
13890
13946
if (VT.isScalableVector() ||
13891
13947
useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
13892
13948
unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
@@ -17711,9 +17767,6 @@ static SDValue performReinterpretCastCombine(SDNode *N) {
17711
17767
17712
17768
static SDValue performSVEAndCombine(SDNode *N,
17713
17769
TargetLowering::DAGCombinerInfo &DCI) {
17714
- if (DCI.isBeforeLegalizeOps())
17715
- return SDValue();
17716
-
17717
17770
SelectionDAG &DAG = DCI.DAG;
17718
17771
SDValue Src = N->getOperand(0);
17719
17772
unsigned Opc = Src->getOpcode();
@@ -17769,6 +17822,9 @@ static SDValue performSVEAndCombine(SDNode *N,
17769
17822
return DAG.getNode(Opc, DL, N->getValueType(0), And);
17770
17823
}
17771
17824
17825
+ if (DCI.isBeforeLegalizeOps())
17826
+ return SDValue();
17827
+
17772
17828
// If both sides of AND operations are i1 splat_vectors then
17773
17829
// we can produce just i1 splat_vector as the result.
17774
17830
if (isAllActivePredicate(DAG, N->getOperand(0)))
@@ -20216,6 +20272,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
20216
20272
case Intrinsic::aarch64_sve_uqsub_x:
20217
20273
return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
20218
20274
N->getOperand(1), N->getOperand(2));
20275
+ case Intrinsic::aarch64_sve_urshr:
20276
+ return DAG.getNode(AArch64ISD::URSHR_I_PRED, SDLoc(N), N->getValueType(0),
20277
+ N->getOperand(1), N->getOperand(2), N->getOperand(3));
20219
20278
case Intrinsic::aarch64_sve_asrd:
20220
20279
return DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, SDLoc(N), N->getValueType(0),
20221
20280
N->getOperand(1), N->getOperand(2), N->getOperand(3));
@@ -20832,6 +20891,51 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
20832
20891
return SDValue();
20833
20892
}
20834
20893
20894
+ static bool isHalvingTruncateAndConcatOfLegalIntScalableType(SDNode *N) {
20895
+ if (N->getOpcode() != AArch64ISD::UZP1)
20896
+ return false;
20897
+ SDValue Op0 = N->getOperand(0);
20898
+ EVT SrcVT = Op0->getValueType(0);
20899
+ EVT DstVT = N->getValueType(0);
20900
+ return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv16i8) ||
20901
+ (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv8i16) ||
20902
+ (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv4i32);
20903
+ }
20904
+
20905
+ // Try to combine rounding shifts where the operands come from an extend, and
20906
+ // the result is truncated and combined into one vector.
20907
+ // uzp1(rshrnb(uunpklo(X),C), rshrnb(uunpkhi(X), C)) -> urshr(X, C)
20908
+ static SDValue tryCombineExtendRShTrunc(SDNode *N, SelectionDAG &DAG) {
20909
+ assert(N->getOpcode() == AArch64ISD::UZP1 && "Only UZP1 expected.");
20910
+ SDValue Op0 = N->getOperand(0);
20911
+ SDValue Op1 = N->getOperand(1);
20912
+ EVT ResVT = N->getValueType(0);
20913
+
20914
+ unsigned RshOpc = Op0.getOpcode();
20915
+ if (RshOpc != AArch64ISD::RSHRNB_I)
20916
+ return SDValue();
20917
+
20918
+ // Same op code and imm value?
20919
+ SDValue ShiftValue = Op0.getOperand(1);
20920
+ if (RshOpc != Op1.getOpcode() || ShiftValue != Op1.getOperand(1))
20921
+ return SDValue();
20922
+
20923
+ // Same unextended operand value?
20924
+ SDValue Lo = Op0.getOperand(0);
20925
+ SDValue Hi = Op1.getOperand(0);
20926
+ if (Lo.getOpcode() != AArch64ISD::UUNPKLO &&
20927
+ Hi.getOpcode() != AArch64ISD::UUNPKHI)
20928
+ return SDValue();
20929
+ SDValue OrigArg = Lo.getOperand(0);
20930
+ if (OrigArg != Hi.getOperand(0))
20931
+ return SDValue();
20932
+
20933
+ SDLoc DL(N);
20934
+ return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, ResVT,
20935
+ getPredicateForVector(DAG, DL, ResVT), OrigArg,
20936
+ ShiftValue);
20937
+ }
20938
+
20835
20939
// Try to simplify:
20836
20940
// t1 = nxv8i16 add(X, 1 << (ShiftValue - 1))
20837
20941
// t2 = nxv8i16 srl(t1, ShiftValue)
@@ -20844,9 +20948,7 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
20844
20948
static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
20845
20949
const AArch64Subtarget *Subtarget) {
20846
20950
EVT VT = Srl->getValueType(0);
20847
-
20848
- if (!VT.isScalableVector() || !Subtarget->hasSVE2() ||
20849
- Srl->getOpcode() != ISD::SRL)
20951
+ if (!VT.isScalableVector() || !Subtarget->hasSVE2())
20850
20952
return SDValue();
20851
20953
20852
20954
EVT ResVT;
@@ -20859,29 +20961,14 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
20859
20961
else
20860
20962
return SDValue();
20861
20963
20862
- auto SrlOp1 =
20863
- dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Srl->getOperand(1)));
20864
- if (!SrlOp1)
20865
- return SDValue();
20866
- unsigned ShiftValue = SrlOp1->getZExtValue();
20867
- if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
20868
- return SDValue();
20869
-
20870
- SDValue Add = Srl->getOperand(0);
20871
- if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
20872
- return SDValue();
20873
- auto AddOp1 =
20874
- dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
20875
- if (!AddOp1)
20876
- return SDValue();
20877
- uint64_t AddValue = AddOp1->getZExtValue();
20878
- if (AddValue != 1ULL << (ShiftValue - 1))
20879
- return SDValue();
20880
-
20881
20964
SDLoc DL(Srl);
20965
+ unsigned ShiftValue;
20966
+ SDValue RShOperand;
20967
+ if (!canLowerSRLToRoundingShiftForVT(Srl, ResVT, DAG, ShiftValue, RShOperand))
20968
+ return SDValue();
20882
20969
SDValue Rshrnb = DAG.getNode(
20883
20970
AArch64ISD::RSHRNB_I, DL, ResVT,
20884
- {Add->getOperand(0) , DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
20971
+ {RShOperand , DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
20885
20972
return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
20886
20973
}
20887
20974
@@ -20919,6 +21006,9 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
20919
21006
}
20920
21007
}
20921
21008
21009
+ if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
21010
+ return Urshr;
21011
+
20922
21012
if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
20923
21013
return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
20924
21014
@@ -20949,6 +21039,19 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
20949
21039
if (!IsLittleEndian)
20950
21040
return SDValue();
20951
21041
21042
+ // uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)
21043
+ // Example:
21044
+ // nxv4i32 = uzp1 bitcast(nxv4i32 x to nxv2i64), bitcast(nxv4i32 y to nxv2i64)
21045
+ // to
21046
+ // nxv4i32 = uzp1 nxv4i32 x, nxv4i32 y
21047
+ if (isHalvingTruncateAndConcatOfLegalIntScalableType(N) &&
21048
+ Op0.getOpcode() == ISD::BITCAST && Op1.getOpcode() == ISD::BITCAST) {
21049
+ if (Op0.getOperand(0).getValueType() == Op1.getOperand(0).getValueType()) {
21050
+ return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0.getOperand(0),
21051
+ Op1.getOperand(0));
21052
+ }
21053
+ }
21054
+
20952
21055
if (ResVT != MVT::v2i32 && ResVT != MVT::v4i16 && ResVT != MVT::v8i8)
20953
21056
return SDValue();
20954
21057
0 commit comments