@@ -2649,6 +2649,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::MSRR)
     MAKE_CASE(AArch64ISD::RSHRNB_I)
     MAKE_CASE(AArch64ISD::CTTZ_ELTS)
+    MAKE_CASE(AArch64ISD::SRSHR_I_PRED)
+    MAKE_CASE(AArch64ISD::URSHR_I_PRED)
   }
 #undef MAKE_CASE
   return nullptr;
@@ -2933,6 +2935,7 @@ static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
 static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
 static SDValue convertFixedMaskToScalableVector(SDValue Mask,
                                                 SelectionDAG &DAG);
+static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT);
 static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
                                              EVT VT);
@@ -13713,6 +13716,42 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
   return SDValue();
 }
 
+static SDValue tryLowerToRoundingShiftRightByImm(SDValue Shift,
+                                                 SelectionDAG &DAG) {
+  if (Shift->getOpcode() != ISD::SRL && Shift->getOpcode() != ISD::SRA)
+    return SDValue();
+
+  EVT ResVT = Shift.getValueType();
+  assert(ResVT.isScalableVT());
+
+  auto ShiftOp1 =
+      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Shift->getOperand(1)));
+  if (!ShiftOp1)
+    return SDValue();
+  unsigned ShiftValue = ShiftOp1->getZExtValue();
+
+  if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
+    return SDValue();
+
+  SDValue Add = Shift->getOperand(0);
+  if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
+    return SDValue();
+  auto AddOp1 =
+      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
+  if (!AddOp1)
+    return SDValue();
+  uint64_t AddValue = AddOp1->getZExtValue();
+  if (AddValue != 1ULL << (ShiftValue - 1))
+    return SDValue();
+
+  SDLoc DL(Shift);
+  unsigned Opc = Shift->getOpcode() == ISD::SRA ? AArch64ISD::SRSHR_I_PRED
+                                                : AArch64ISD::URSHR_I_PRED;
+  return DAG.getNode(Opc, DL, ResVT, getPredicateForVector(DAG, DL, ResVT),
+                     Add->getOperand(0),
+                     DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
+}
+
 SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                                                       SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
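Note: the tryLowerToRoundingShiftRightByImm combine above relies on the identity that a rounding right shift of X by Sh equals a plain right shift of X plus the bias (1 << (Sh - 1)). A minimal scalar sketch of that arithmetic, with illustrative names that are not part of the patch:

    #include <cassert>
    #include <cstdint>

    // Rounding shift right: add half of the divisor before shifting so the
    // result rounds to nearest instead of truncating toward zero.
    static uint64_t roundingShiftRight(uint64_t X, unsigned Sh) {
      assert(Sh >= 1 && Sh <= 63 && "shift amount out of range");
      return (X + (1ULL << (Sh - 1))) >> Sh;
    }

    int main() {
      assert((13u >> 2) == 3);                // plain shift truncates 3.25 to 3
      assert(roundingShiftRight(13, 2) == 3); // 3.25 rounds to 3
      assert(roundingShiftRight(14, 2) == 4); // 3.5 rounds up to 4
      return 0;
    }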
@@ -13738,6 +13777,10 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                        Op.getOperand(0), Op.getOperand(1));
   case ISD::SRA:
   case ISD::SRL:
+    if (VT.isScalableVector() && Subtarget->hasSVE2orSME())
+      if (SDValue RSH = tryLowerToRoundingShiftRightByImm(Op, DAG))
+        return RSH;
+
     if (VT.isScalableVector() ||
         useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
       unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
@@ -20025,6 +20068,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_uqsub_x:
     return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2));
+  case Intrinsic::aarch64_sve_srshr:
+    return DAG.getNode(AArch64ISD::SRSHR_I_PRED, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2), N->getOperand(3));
+  case Intrinsic::aarch64_sve_urshr:
+    return DAG.getNode(AArch64ISD::URSHR_I_PRED, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2), N->getOperand(3));
   case Intrinsic::aarch64_sve_asrd:
     return DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
@@ -20652,12 +20701,13 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
 // a uzp1 or a truncating store.
 static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
                                          const AArch64Subtarget *Subtarget) {
-  EVT VT = Srl->getValueType(0);
+  if (Srl->getOpcode() != AArch64ISD::URSHR_I_PRED)
+    return SDValue();
 
-  if (!VT.isScalableVector() || !Subtarget->hasSVE2() ||
-      Srl->getOpcode() != ISD::SRL)
+  if (!isAllActivePredicate(DAG, Srl.getOperand(0)))
     return SDValue();
 
+  EVT VT = Srl->getValueType(0);
   EVT ResVT;
   if (VT == MVT::nxv8i16)
     ResVT = MVT::nxv16i8;
@@ -20668,29 +20718,14 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
   else
     return SDValue();
 
-  auto SrlOp1 =
-      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Srl->getOperand(1)));
-  if (!SrlOp1)
-    return SDValue();
-  unsigned ShiftValue = SrlOp1->getZExtValue();
-  if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
-    return SDValue();
-
-  SDValue Add = Srl->getOperand(0);
-  if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
-    return SDValue();
-  auto AddOp1 =
-      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
-  if (!AddOp1)
-    return SDValue();
-  uint64_t AddValue = AddOp1->getZExtValue();
-  if (AddValue != 1ULL << (ShiftValue - 1))
+  unsigned ShiftValue =
+      cast<ConstantSDNode>(Srl->getOperand(2))->getZExtValue();
+  if (ShiftValue > ResVT.getScalarSizeInBits())
     return SDValue();
 
   SDLoc DL(Srl);
-  SDValue Rshrnb = DAG.getNode(
-      AArch64ISD::RSHRNB_I, DL, ResVT,
-      {Add->getOperand(0), DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
+  SDValue Rshrnb = DAG.getNode(AArch64ISD::RSHRNB_I, DL, ResVT,
+                               {Srl->getOperand(1), Srl->getOperand(2)});
   return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
 }
 
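Note: with the earlier combine in place, the add+srl pattern reaches trySimplifySrlAddToRshrnb already folded to URSHR_I_PRED, so the function now only checks for an all-active predicate and reuses the node's operands directly. A rough scalar sketch of the per-lane rounding-and-narrowing that RSHRNB_I models, again with illustrative names that are not part of the patch:

    #include <cassert>
    #include <cstdint>

    // Scalar model of one RSHRNB lane: rounding shift right, then narrow to
    // half the element width. In the combine above, the shift amount and the
    // shifted operand are taken straight from the URSHR_I_PRED node.
    static uint8_t rshrnbLane16to8(uint16_t X, unsigned Sh) {
      assert(Sh >= 1 && Sh <= 8 && "narrowing shift out of range");
      return static_cast<uint8_t>((X + (1u << (Sh - 1))) >> Sh);
    }

    int main() {
      assert(rshrnbLane16to8(0x01FF, 4) == 0x20); // (511 + 8) >> 4 = 32
      return 0;
    }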