Commit 1d14323
[AArch64][SVE2] Generate urshr rounding shift rights (#78374)
Add a new node `AArch64ISD::URSHR_I_PRED`.

`srl(add(X, 1 << (ShiftValue - 1)), ShiftValue)` is transformed into `urshr`, or into `rshrnb` (as before) when the result is truncated.

`uzp1(rshrnb(uunpklo(X), C), rshrnb(uunpkhi(X), C))` is converted into `urshr(X, C)` (tested by the wide_trunc tests).

The pattern-matching code in `canLowerSRLToRoundingShiftForVT` is factored out of the previous rshrnb lowering. It returns true if the add has NUW, or if the number of bits used in the return value lets us ignore the overflow (tested by the rshrnb test cases).
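For context, here is a scalar model (my sketch, not part of the patch) of why the matched DAG computes a rounding shift: adding `1 << (s - 1)` before an unsigned shift right by `s` turns truncating division by `2^s` into round-to-nearest, which is the per-lane semantics of SVE2's `urshr`.

```cpp
#include <cassert>
#include <cstdint>

// srl(add(x, 1 << (s - 1)), s) == x >> s rounded to nearest, ties up.
// The reference result is computed independently from quotient and
// remainder, so the assertion is a genuine cross-check.
int main() {
  const unsigned s = 4; // any shift where x + (1 << (s - 1)) cannot wrap
  for (uint32_t x = 0; x <= 0xFFFF; ++x) {
    uint32_t viaAddSrl = (x + (1u << (s - 1))) >> s;
    uint32_t quotient = x >> s;
    uint32_t remainder = x & ((1u << s) - 1);
    uint32_t rounded = quotient + (remainder >= (1u << (s - 1)) ? 1 : 0);
    assert(viaAddSrl == rounded);
  }
  return 0;
}
```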
1 parent 56e241a · commit 1d14323

File tree: 4 files changed, +412 -28 lines

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 129 additions & 26 deletions
```diff
@@ -2690,6 +2690,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::RSHRNB_I)
     MAKE_CASE(AArch64ISD::CTTZ_ELTS)
     MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
+    MAKE_CASE(AArch64ISD::URSHR_I_PRED)
   }
 #undef MAKE_CASE
   return nullptr;
@@ -2974,6 +2975,7 @@ static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
 static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
 static SDValue convertFixedMaskToScalableVector(SDValue Mask,
                                                 SelectionDAG &DAG);
+static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT);
 static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
                                              EVT VT);
```

```diff
@@ -13862,6 +13864,51 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
   return SDValue();
 }
 
+// Check whether we can lower this SRL to a rounding shift instruction. ResVT
+// is possibly a truncated type; it tells how many bits of the value are to
+// be used.
+static bool canLowerSRLToRoundingShiftForVT(SDValue Shift, EVT ResVT,
+                                            SelectionDAG &DAG,
+                                            unsigned &ShiftValue,
+                                            SDValue &RShOperand) {
+  if (Shift->getOpcode() != ISD::SRL)
+    return false;
+
+  EVT VT = Shift.getValueType();
+  assert(VT.isScalableVT());
+
+  auto ShiftOp1 =
+      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Shift->getOperand(1)));
+  if (!ShiftOp1)
+    return false;
+
+  ShiftValue = ShiftOp1->getZExtValue();
+  if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
+    return false;
+
+  SDValue Add = Shift->getOperand(0);
+  if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
+    return false;
+
+  assert(ResVT.getScalarSizeInBits() <= VT.getScalarSizeInBits() &&
+         "ResVT must be truncated or same type as the shift.");
+  // Check if an overflow can lead to incorrect results.
+  uint64_t ExtraBits = VT.getScalarSizeInBits() - ResVT.getScalarSizeInBits();
+  if (ShiftValue > ExtraBits && !Add->getFlags().hasNoUnsignedWrap())
+    return false;
+
+  auto AddOp1 =
+      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
+  if (!AddOp1)
+    return false;
+  uint64_t AddValue = AddOp1->getZExtValue();
+  if (AddValue != 1ULL << (ShiftValue - 1))
+    return false;
+
+  RShOperand = Add->getOperand(0);
+  return true;
+}
+
 SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                                                       SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
```
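The `ExtraBits` check is the subtle part of `canLowerSRLToRoundingShiftForVT`: the rounding add may wrap in the wide element type, but a wrap only loses the carry out of the top bit, while the truncated result keeps only bits `[ShiftValue, ShiftValue + ResBits)` of the sum. Those bits are untouched whenever `ShiftValue <= ExtraBits`, so NUW is only required beyond that. A self-contained check of this claim (my sketch, with 16-bit elements truncated to 8 bits standing in for the scalable element types):

```cpp
#include <cassert>
#include <cstdint>

// W = 16 wide bits, R = 8 result bits, so ExtraBits = W - R = 8.
// With s == 8 (== ExtraBits) the add may wrap in 16 bits, yet the 8 bits
// the truncated result keeps (bits s..s+R-1 of the sum) never depend on the
// lost carry at bit 16. With s == 9 the kept bits would include bit 16,
// which is why the code then insists on the no-unsigned-wrap flag.
int main() {
  const unsigned s = 8;
  for (uint32_t x = 0; x <= 0xFFFF; ++x) {
    uint16_t wrapped = (uint16_t)(x + (1u << (s - 1))); // add done in W bits
    uint32_t exact = x + (1u << (s - 1));               // add without wrap
    assert((uint8_t)(wrapped >> s) == (uint8_t)(exact >> s));
  }
  return 0;
}
```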
```diff
@@ -13887,6 +13934,15 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                        Op.getOperand(0), Op.getOperand(1));
   case ISD::SRA:
   case ISD::SRL:
+    if (VT.isScalableVector() && Subtarget->hasSVE2orSME()) {
+      SDValue RShOperand;
+      unsigned ShiftValue;
+      if (canLowerSRLToRoundingShiftForVT(Op, VT, DAG, ShiftValue, RShOperand))
+        return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, VT,
+                           getPredicateForVector(DAG, DL, VT), RShOperand,
+                           DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
+    }
+
     if (VT.isScalableVector() ||
         useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
       unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
```
```diff
@@ -17711,9 +17767,6 @@ static SDValue performReinterpretCastCombine(SDNode *N) {
 
 static SDValue performSVEAndCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
-  if (DCI.isBeforeLegalizeOps())
-    return SDValue();
-
   SelectionDAG &DAG = DCI.DAG;
   SDValue Src = N->getOperand(0);
   unsigned Opc = Src->getOpcode();
@@ -17769,6 +17822,9 @@ static SDValue performSVEAndCombine(SDNode *N,
     return DAG.getNode(Opc, DL, N->getValueType(0), And);
   }
 
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
   // If both sides of AND operations are i1 splat_vectors then
   // we can produce just i1 splat_vector as the result.
   if (isAllActivePredicate(DAG, N->getOperand(0)))
```
```diff
@@ -20216,6 +20272,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_uqsub_x:
     return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2));
+  case Intrinsic::aarch64_sve_urshr:
+    return DAG.getNode(AArch64ISD::URSHR_I_PRED, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2), N->getOperand(3));
   case Intrinsic::aarch64_sve_asrd:
     return DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2), N->getOperand(3));
```
```diff
@@ -20832,6 +20891,51 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static bool isHalvingTruncateAndConcatOfLegalIntScalableType(SDNode *N) {
+  if (N->getOpcode() != AArch64ISD::UZP1)
+    return false;
+  SDValue Op0 = N->getOperand(0);
+  EVT SrcVT = Op0->getValueType(0);
+  EVT DstVT = N->getValueType(0);
+  return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv16i8) ||
+         (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv8i16) ||
+         (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv4i32);
+}
+
+// Try to combine rounding shifts where the operands come from an extend, and
+// the result is truncated and combined into one vector.
+// uzp1(rshrnb(uunpklo(X),C), rshrnb(uunpkhi(X), C)) -> urshr(X, C)
+static SDValue tryCombineExtendRShTrunc(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == AArch64ISD::UZP1 && "Only UZP1 expected.");
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  EVT ResVT = N->getValueType(0);
+
+  unsigned RshOpc = Op0.getOpcode();
+  if (RshOpc != AArch64ISD::RSHRNB_I)
+    return SDValue();
+
+  // Same op code and imm value?
+  SDValue ShiftValue = Op0.getOperand(1);
+  if (RshOpc != Op1.getOpcode() || ShiftValue != Op1.getOperand(1))
+    return SDValue();
+
+  // Same unextended operand value?
+  SDValue Lo = Op0.getOperand(0);
+  SDValue Hi = Op1.getOperand(0);
+  if (Lo.getOpcode() != AArch64ISD::UUNPKLO &&
+      Hi.getOpcode() != AArch64ISD::UUNPKHI)
+    return SDValue();
+  SDValue OrigArg = Lo.getOperand(0);
+  if (OrigArg != Hi.getOperand(0))
+    return SDValue();
+
+  SDLoc DL(N);
+  return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, ResVT,
+                     getPredicateForVector(DAG, DL, ResVT), OrigArg,
+                     ShiftValue);
+}
+
 // Try to simplify:
 // t1 = nxv8i16 add(X, 1 << (ShiftValue - 1))
 // t2 = nxv8i16 srl(t1, ShiftValue)
```
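A note on why `tryCombineExtendRShTrunc` is sound: `uunpklo`/`uunpkhi` zero-extend each lane, so the rounding add gains spare high bits and cannot wrap, and `urshr` is specified with a wide-enough intermediate, so rounding directly in the narrow lanes produces the same value. A per-lane scalar model (my sketch, not LLVM code), exhaustive over one nxv8i16 lane:

```cpp
#include <cassert>
#include <cstdint>

// One 16-bit lane x of the matched pattern
//   uzp1(rshrnb(uunpklo(X), C), rshrnb(uunpkhi(X), C))
// compared against urshr(X, C) for every shift amount C.
int main() {
  for (unsigned c = 1; c <= 16; ++c) {
    for (uint32_t x = 0; x <= 0xFFFF; ++x) {
      // uunpk + rshrnb path: round and shift in 32 bits, then narrow.
      uint16_t widened = (uint16_t)((x + (1u << (c - 1))) >> c);
      // urshr path: rounding shift on the lane itself; the architectural
      // intermediate does not wrap, modelled here with 64-bit arithmetic.
      uint16_t direct = (uint16_t)(((uint64_t)x + (1u << (c - 1))) >> c);
      assert(widened == direct);
    }
  }
  return 0;
}
```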
```diff
@@ -20844,9 +20948,7 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
 static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
                                          const AArch64Subtarget *Subtarget) {
   EVT VT = Srl->getValueType(0);
-
-  if (!VT.isScalableVector() || !Subtarget->hasSVE2() ||
-      Srl->getOpcode() != ISD::SRL)
+  if (!VT.isScalableVector() || !Subtarget->hasSVE2())
     return SDValue();
 
   EVT ResVT;
@@ -20859,29 +20961,14 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
   else
     return SDValue();
 
-  auto SrlOp1 =
-      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Srl->getOperand(1)));
-  if (!SrlOp1)
-    return SDValue();
-  unsigned ShiftValue = SrlOp1->getZExtValue();
-  if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
-    return SDValue();
-
-  SDValue Add = Srl->getOperand(0);
-  if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
-    return SDValue();
-  auto AddOp1 =
-      dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
-  if (!AddOp1)
-    return SDValue();
-  uint64_t AddValue = AddOp1->getZExtValue();
-  if (AddValue != 1ULL << (ShiftValue - 1))
-    return SDValue();
-
   SDLoc DL(Srl);
+  unsigned ShiftValue;
+  SDValue RShOperand;
+  if (!canLowerSRLToRoundingShiftForVT(Srl, ResVT, DAG, ShiftValue, RShOperand))
+    return SDValue();
   SDValue Rshrnb = DAG.getNode(
       AArch64ISD::RSHRNB_I, DL, ResVT,
-      {Add->getOperand(0), DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
+      {RShOperand, DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
   return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
 }
```
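The trailing `ISD::BITCAST` in `trySimplifySrlAddToRshrnb` works because `RSHRNB_I` is a bottom-half narrowing shift: each rounded result is written to the low (even) half of its lane pair and the high half is zeroed, so on little endian the bytes reinterpret as the zero-extended wide result. A one-lane model (my sketch; the surrounding `uzp1` only consumes the low half, which is why the i8 truncation is acceptable):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// One nxv8i16 lane: srl(add(x, 1 << (s - 1)), s) versus "rshrnb then
// bitcast back to i16". rshrnb puts the rounded i8 in the bottom byte and
// zeroes the top byte; little-endian reinterpretation then matches the
// plain shift whenever the rounded value fits in 8 bits.
int main() {
  const unsigned s = 4;
  uint16_t lane = 0x0ABC;
  uint16_t srlAdd = (uint16_t)((lane + (1u << (s - 1))) >> s);      // 0x00AC
  uint8_t pair[2] = {(uint8_t)((lane + (1u << (s - 1))) >> s), 0};  // rshrnb
  uint16_t viaBitcast;
  memcpy(&viaBitcast, pair, sizeof viaBitcast); // models ISD::BITCAST (LE)
  assert(viaBitcast == srlAdd);
  return 0;
}
```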
```diff
@@ -20919,6 +21006,9 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
+    return Urshr;
+
   if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
     return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
 
@@ -20949,6 +21039,19 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
   if (!IsLittleEndian)
     return SDValue();
 
+  // uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)
+  // Example:
+  // nxv4i32 = uzp1 bitcast(nxv4i32 x to nxv2i64), bitcast(nxv4i32 y to nxv2i64)
+  // to
+  // nxv4i32 = uzp1 nxv4i32 x, nxv4i32 y
+  if (isHalvingTruncateAndConcatOfLegalIntScalableType(N) &&
+      Op0.getOpcode() == ISD::BITCAST && Op1.getOpcode() == ISD::BITCAST) {
+    if (Op0.getOperand(0).getValueType() == Op1.getOperand(0).getValueType()) {
+      return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0.getOperand(0),
+                         Op1.getOperand(0));
+    }
+  }
+
   if (ResVT != MVT::v2i32 && ResVT != MVT::v4i16 && ResVT != MVT::v8i8)
     return SDValue();
```
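The new `uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)` fold rests on `uzp1` picking the even result-type lanes purely by byte position, while a bitcast never moves bytes; on little endian the casts are therefore dead. A fixed-width 128-bit model (my sketch; `uzp1_i32` is a hypothetical stand-in for the nxv4i32 node):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// uzp1 at i32 granularity: even-indexed 32-bit lanes of concat(a, b).
static void uzp1_i32(const uint32_t a[4], const uint32_t b[4],
                     uint32_t out[4]) {
  out[0] = a[0]; out[1] = a[2]; out[2] = b[0]; out[3] = b[2];
}

int main() {
  uint32_t x[4] = {1, 2, 3, 4}, y[4] = {5, 6, 7, 8};
  // bitcast nxv4i32 -> nxv2i64 -> nxv4i32: a byte-identical round trip.
  uint64_t xAsI64[2], yAsI64[2];
  memcpy(xAsI64, x, sizeof x);
  memcpy(yAsI64, y, sizeof y);
  uint32_t xBack[4], yBack[4];
  memcpy(xBack, xAsI64, sizeof xBack);
  memcpy(yBack, yAsI64, sizeof yBack);
  uint32_t viaCasts[4], direct[4];
  uzp1_i32(xBack, yBack, viaCasts); // uzp1(bitcast(x), bitcast(y))
  uzp1_i32(x, y, direct);           // uzp1(x, y)
  assert(memcmp(viaCasts, direct, sizeof direct) == 0);
  return 0;
}
```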
llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -218,6 +218,7 @@ enum NodeType : unsigned {
   SQSHLU_I,
   SRSHR_I,
   URSHR_I,
+  URSHR_I_PRED,
 
   // Vector narrowing shift by immediate (bottom)
   RSHRNB_I,
```

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 3 additions & 2 deletions
```diff
@@ -232,6 +232,7 @@ def SDT_AArch64Arith_Imm : SDTypeProfile<1, 3, [
 ]>;
 
 def AArch64asrd_m1 : SDNode<"AArch64ISD::SRAD_MERGE_OP1", SDT_AArch64Arith_Imm>;
+def AArch64urshri_p : SDNode<"AArch64ISD::URSHR_I_PRED", SDT_AArch64Arith_Imm>;
 
 def SDT_AArch64IntExtend : SDTypeProfile<1, 4, [
   SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVT<3, OtherVT>, SDTCisVec<4>,
@@ -3539,7 +3540,7 @@ let Predicates = [HasSVE2orSME] in {
   defm SQSHL_ZPmI  : sve_int_bin_pred_shift_imm_left_dup<0b0110, "sqshl", "SQSHL_ZPZI", int_aarch64_sve_sqshl>;
   defm UQSHL_ZPmI  : sve_int_bin_pred_shift_imm_left_dup<0b0111, "uqshl", "UQSHL_ZPZI", int_aarch64_sve_uqshl>;
   defm SRSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<0b1100, "srshr", "SRSHR_ZPZI", int_aarch64_sve_srshr>;
-  defm URSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<0b1101, "urshr", "URSHR_ZPZI", int_aarch64_sve_urshr>;
+  defm URSHR_ZPmI  : sve_int_bin_pred_shift_imm_right<0b1101, "urshr", "URSHR_ZPZI", AArch64urshri_p>;
   defm SQSHLU_ZPmI : sve_int_bin_pred_shift_imm_left<0b1111, "sqshlu", "SQSHLU_ZPZI", int_aarch64_sve_sqshlu>;
 
   // SVE2 integer add/subtract long
@@ -3584,7 +3585,7 @@ let Predicates = [HasSVE2orSME] in {
   defm SSRA_ZZI  : sve2_int_bin_accum_shift_imm_right<0b00, "ssra", AArch64ssra>;
   defm USRA_ZZI  : sve2_int_bin_accum_shift_imm_right<0b01, "usra", AArch64usra>;
   defm SRSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b10, "srsra", int_aarch64_sve_srsra, int_aarch64_sve_srshr>;
-  defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, int_aarch64_sve_urshr>;
+  defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, AArch64urshri_p>;
 
   // SVE2 complex integer add
   defm CADD_ZZI : sve2_int_cadd<0b0, "cadd", int_aarch64_sve_cadd_x>;
```
