Skip to content

Commit b33d000

Browse files
committed
[AArch64][SVE2] Generate signed/unsigned rounding shift rights
The matching code is similar to that for rshrnb, except that the immediate shift value has a larger range and signed shifts are supported. rshrnb now uses the new AArch64ISD node for uniform rounding. Change-Id: Idbb811f318d33c7637371cf7bb00285d20e1771d
1 parent cd753c7 commit b33d000

File tree

5 files changed

+276
-37
lines changed

5 files changed

+276
-37
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2649,6 +2649,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
26492649
MAKE_CASE(AArch64ISD::MSRR)
26502650
MAKE_CASE(AArch64ISD::RSHRNB_I)
26512651
MAKE_CASE(AArch64ISD::CTTZ_ELTS)
2652+
MAKE_CASE(AArch64ISD::SRSHR_I_PRED)
2653+
MAKE_CASE(AArch64ISD::URSHR_I_PRED)
26522654
}
26532655
#undef MAKE_CASE
26542656
return nullptr;
@@ -2933,6 +2935,7 @@ static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
29332935
static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
29342936
static SDValue convertFixedMaskToScalableVector(SDValue Mask,
29352937
SelectionDAG &DAG);
2938+
static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT);
29362939
static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
29372940
EVT VT);
29382941

@@ -13713,6 +13716,42 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
1371313716
return SDValue();
1371413717
}
1371513718

13719+
static SDValue tryLowerToRoundingShiftRightByImm(SDValue Shift,
13720+
SelectionDAG &DAG) {
13721+
if (Shift->getOpcode() != ISD::SRL && Shift->getOpcode() != ISD::SRA)
13722+
return SDValue();
13723+
13724+
EVT ResVT = Shift.getValueType();
13725+
assert(ResVT.isScalableVT());
13726+
13727+
auto ShiftOp1 =
13728+
dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Shift->getOperand(1)));
13729+
if (!ShiftOp1)
13730+
return SDValue();
13731+
unsigned ShiftValue = ShiftOp1->getZExtValue();
13732+
13733+
if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
13734+
return SDValue();
13735+
13736+
SDValue Add = Shift->getOperand(0);
13737+
if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
13738+
return SDValue();
13739+
auto AddOp1 =
13740+
dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
13741+
if (!AddOp1)
13742+
return SDValue();
13743+
uint64_t AddValue = AddOp1->getZExtValue();
13744+
if (AddValue != 1ULL << (ShiftValue - 1))
13745+
return SDValue();
13746+
13747+
SDLoc DL(Shift);
13748+
unsigned Opc = Shift->getOpcode() == ISD::SRA ? AArch64ISD::SRSHR_I_PRED
13749+
: AArch64ISD::URSHR_I_PRED;
13750+
return DAG.getNode(Opc, DL, ResVT, getPredicateForVector(DAG, DL, ResVT),
13751+
Add->getOperand(0),
13752+
DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
13753+
}
13754+
1371613755
SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
1371713756
SelectionDAG &DAG) const {
1371813757
EVT VT = Op.getValueType();
@@ -13738,6 +13777,10 @@ SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
1373813777
Op.getOperand(0), Op.getOperand(1));
1373913778
case ISD::SRA:
1374013779
case ISD::SRL:
13780+
if (VT.isScalableVector() && Subtarget->hasSVE2orSME())
13781+
if (SDValue RSH = tryLowerToRoundingShiftRightByImm(Op, DAG))
13782+
return RSH;
13783+
1374113784
if (VT.isScalableVector() ||
1374213785
useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
1374313786
unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
@@ -20025,6 +20068,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
2002520068
case Intrinsic::aarch64_sve_uqsub_x:
2002620069
return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
2002720070
N->getOperand(1), N->getOperand(2));
20071+
case Intrinsic::aarch64_sve_srshr:
20072+
return DAG.getNode(AArch64ISD::SRSHR_I_PRED, SDLoc(N), N->getValueType(0),
20073+
N->getOperand(1), N->getOperand(2), N->getOperand(3));
20074+
case Intrinsic::aarch64_sve_urshr:
20075+
return DAG.getNode(AArch64ISD::URSHR_I_PRED, SDLoc(N), N->getValueType(0),
20076+
N->getOperand(1), N->getOperand(2), N->getOperand(3));
2002820077
case Intrinsic::aarch64_sve_asrd:
2002920078
return DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, SDLoc(N), N->getValueType(0),
2003020079
N->getOperand(1), N->getOperand(2), N->getOperand(3));
@@ -20652,12 +20701,13 @@ static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
2065220701
// a uzp1 or a truncating store.
2065320702
static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
2065420703
const AArch64Subtarget *Subtarget) {
20655-
EVT VT = Srl->getValueType(0);
20704+
if (Srl->getOpcode() != AArch64ISD::URSHR_I_PRED)
20705+
return SDValue();
2065620706

20657-
if (!VT.isScalableVector() || !Subtarget->hasSVE2() ||
20658-
Srl->getOpcode() != ISD::SRL)
20707+
if (!isAllActivePredicate(DAG, Srl.getOperand(0)))
2065920708
return SDValue();
2066020709

20710+
EVT VT = Srl->getValueType(0);
2066120711
EVT ResVT;
2066220712
if (VT == MVT::nxv8i16)
2066320713
ResVT = MVT::nxv16i8;
@@ -20668,29 +20718,14 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
2066820718
else
2066920719
return SDValue();
2067020720

20671-
auto SrlOp1 =
20672-
dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Srl->getOperand(1)));
20673-
if (!SrlOp1)
20674-
return SDValue();
20675-
unsigned ShiftValue = SrlOp1->getZExtValue();
20676-
if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
20677-
return SDValue();
20678-
20679-
SDValue Add = Srl->getOperand(0);
20680-
if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
20681-
return SDValue();
20682-
auto AddOp1 =
20683-
dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
20684-
if (!AddOp1)
20685-
return SDValue();
20686-
uint64_t AddValue = AddOp1->getZExtValue();
20687-
if (AddValue != 1ULL << (ShiftValue - 1))
20721+
unsigned ShiftValue =
20722+
cast<ConstantSDNode>(Srl->getOperand(2))->getZExtValue();
20723+
if (ShiftValue > ResVT.getScalarSizeInBits())
2068820724
return SDValue();
2068920725

2069020726
SDLoc DL(Srl);
20691-
SDValue Rshrnb = DAG.getNode(
20692-
AArch64ISD::RSHRNB_I, DL, ResVT,
20693-
{Add->getOperand(0), DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
20727+
SDValue Rshrnb = DAG.getNode(AArch64ISD::RSHRNB_I, DL, ResVT,
20728+
{Srl->getOperand(1), Srl->getOperand(2)});
2069420729
return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
2069520730
}
2069620731

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,9 @@ enum NodeType : unsigned {
210210
UQSHL_I,
211211
SQSHLU_I,
212212
SRSHR_I,
213+
SRSHR_I_PRED,
213214
URSHR_I,
215+
URSHR_I_PRED,
214216

215217
// Vector narrowing shift by immediate (bottom)
216218
RSHRNB_I,

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,8 @@ def SDT_AArch64Arith_Imm : SDTypeProfile<1, 3, [
232232
]>;
233233

234234
def AArch64asrd_m1 : SDNode<"AArch64ISD::SRAD_MERGE_OP1", SDT_AArch64Arith_Imm>;
235+
def AArch64urshri_p : SDNode<"AArch64ISD::URSHR_I_PRED", SDT_AArch64Arith_Imm>;
236+
def AArch64srshri_p : SDNode<"AArch64ISD::SRSHR_I_PRED", SDT_AArch64Arith_Imm>;
235237

236238
def SDT_AArch64IntExtend : SDTypeProfile<1, 4, [
237239
SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVT<3, OtherVT>, SDTCisVec<4>,
@@ -3538,8 +3540,8 @@ let Predicates = [HasSVE2orSME] in {
35383540
// SVE2 predicated shifts
35393541
defm SQSHL_ZPmI : sve_int_bin_pred_shift_imm_left_dup<0b0110, "sqshl", "SQSHL_ZPZI", int_aarch64_sve_sqshl>;
35403542
defm UQSHL_ZPmI : sve_int_bin_pred_shift_imm_left_dup<0b0111, "uqshl", "UQSHL_ZPZI", int_aarch64_sve_uqshl>;
3541-
defm SRSHR_ZPmI : sve_int_bin_pred_shift_imm_right< 0b1100, "srshr", "SRSHR_ZPZI", int_aarch64_sve_srshr>;
3542-
defm URSHR_ZPmI : sve_int_bin_pred_shift_imm_right< 0b1101, "urshr", "URSHR_ZPZI", int_aarch64_sve_urshr>;
3543+
defm SRSHR_ZPmI : sve_int_bin_pred_shift_imm_right< 0b1100, "srshr", "SRSHR_ZPZI", AArch64srshri_p>;
3544+
defm URSHR_ZPmI : sve_int_bin_pred_shift_imm_right< 0b1101, "urshr", "URSHR_ZPZI", AArch64urshri_p>;
35433545
defm SQSHLU_ZPmI : sve_int_bin_pred_shift_imm_left< 0b1111, "sqshlu", "SQSHLU_ZPZI", int_aarch64_sve_sqshlu>;
35443546

35453547
// SVE2 integer add/subtract long
@@ -3583,8 +3585,8 @@ let Predicates = [HasSVE2orSME] in {
35833585
// SVE2 bitwise shift right and accumulate
35843586
defm SSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b00, "ssra", AArch64ssra>;
35853587
defm USRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b01, "usra", AArch64usra>;
3586-
defm SRSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b10, "srsra", int_aarch64_sve_srsra, int_aarch64_sve_srshr>;
3587-
defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, int_aarch64_sve_urshr>;
3588+
defm SRSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b10, "srsra", int_aarch64_sve_srsra, AArch64srshri_p>;
3589+
defm URSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b11, "ursra", int_aarch64_sve_ursra, AArch64urshri_p>;
35883590

35893591
// SVE2 complex integer add
35903592
defm CADD_ZZI : sve2_int_cadd<0b0, "cadd", int_aarch64_sve_cadd_x>;

llvm/test/CodeGen/AArch64/sve2-intrinsics-combine-rshrnb.ll

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,16 +184,14 @@ define void @wide_add_shift_add_rshrnb_d(ptr %dest, i64 %index, <vscale x 4 x i6
184184
define void @neg_wide_add_shift_add_rshrnb_d(ptr %dest, i64 %index, <vscale x 4 x i64> %arg1){
185185
; CHECK-LABEL: neg_wide_add_shift_add_rshrnb_d:
186186
; CHECK: // %bb.0:
187-
; CHECK-NEXT: mov z2.d, #0x800000000000
188-
; CHECK-NEXT: ptrue p0.s
189-
; CHECK-NEXT: add z0.d, z0.d, z2.d
190-
; CHECK-NEXT: add z1.d, z1.d, z2.d
191-
; CHECK-NEXT: lsr z1.d, z1.d, #48
192-
; CHECK-NEXT: lsr z0.d, z0.d, #48
187+
; CHECK-NEXT: ptrue p0.d
188+
; CHECK-NEXT: ptrue p1.s
189+
; CHECK-NEXT: urshr z1.d, p0/m, z1.d, #48
190+
; CHECK-NEXT: urshr z0.d, p0/m, z0.d, #48
193191
; CHECK-NEXT: uzp1 z0.s, z0.s, z1.s
194-
; CHECK-NEXT: ld1w { z1.s }, p0/z, [x0, x1, lsl #2]
192+
; CHECK-NEXT: ld1w { z1.s }, p1/z, [x0, x1, lsl #2]
195193
; CHECK-NEXT: add z0.s, z1.s, z0.s
196-
; CHECK-NEXT: st1w { z0.s }, p0, [x0, x1, lsl #2]
194+
; CHECK-NEXT: st1w { z0.s }, p1, [x0, x1, lsl #2]
197195
; CHECK-NEXT: ret
198196
%1 = add <vscale x 4 x i64> %arg1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 140737488355328, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
199197
%2 = lshr <vscale x 4 x i64> %1, shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> poison, i64 48, i64 0), <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer)
@@ -286,8 +284,7 @@ define void @neg_add_lshr_rshrnb_s(ptr %ptr, ptr %dst, i64 %index){
286284
; CHECK: // %bb.0:
287285
; CHECK-NEXT: ptrue p0.d
288286
; CHECK-NEXT: ld1d { z0.d }, p0/z, [x0]
289-
; CHECK-NEXT: add z0.d, z0.d, #32 // =0x20
290-
; CHECK-NEXT: lsr z0.d, z0.d, #6
287+
; CHECK-NEXT: urshr z0.d, p0/m, z0.d, #6
291288
; CHECK-NEXT: st1h { z0.d }, p0, [x1, x2, lsl #1]
292289
; CHECK-NEXT: ret
293290
%load = load <vscale x 2 x i64>, ptr %ptr, align 2

0 commit comments

Comments (0)