Skip to content

Commit 792fa23

Browse files
authored
[AArch64][SVE2] Lower OR to SLI/SRI (#77555)
Code builds on NEON code and the tests are adapted from NEON tests minus the tests for illegal types.
1 parent 9f8c818 commit 792fa23

File tree

3 files changed

+358
-67
lines changed

3 files changed

+358
-67
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 93 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -1358,6 +1358,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
13581358

13591359
if (!Subtarget->isLittleEndian())
13601360
setOperationAction(ISD::BITCAST, VT, Expand);
1361+
1362+
if (Subtarget->hasSVE2orSME())
1363+
// For SLI/SRI.
1364+
setOperationAction(ISD::OR, VT, Custom);
13611365
}
13621366

13631367
// Illegal unpacked integer vector types.
@@ -5409,15 +5413,18 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
54095413
}
54105414

54115415
case Intrinsic::aarch64_neon_vsri:
5412-
case Intrinsic::aarch64_neon_vsli: {
5416+
case Intrinsic::aarch64_neon_vsli:
5417+
case Intrinsic::aarch64_sve_sri:
5418+
case Intrinsic::aarch64_sve_sli: {
54135419
EVT Ty = Op.getValueType();
54145420

54155421
if (!Ty.isVector())
54165422
report_fatal_error("Unexpected type for aarch64_neon_vsli");
54175423

54185424
assert(Op.getConstantOperandVal(3) <= Ty.getScalarSizeInBits());
54195425

5420-
bool IsShiftRight = IntNo == Intrinsic::aarch64_neon_vsri;
5426+
bool IsShiftRight = IntNo == Intrinsic::aarch64_neon_vsri ||
5427+
IntNo == Intrinsic::aarch64_sve_sri;
54215428
unsigned Opcode = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
54225429
return DAG.getNode(Opcode, dl, Ty, Op.getOperand(1), Op.getOperand(2),
54235430
Op.getOperand(3));
@@ -12542,6 +12549,53 @@ static bool isAllConstantBuildVector(const SDValue &PotentialBVec,
1254212549
return true;
1254312550
}
1254412551

12552+
static bool isAllInactivePredicate(SDValue N) {
12553+
// Look through cast.
12554+
while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
12555+
N = N.getOperand(0);
12556+
12557+
return ISD::isConstantSplatVectorAllZeros(N.getNode());
12558+
}
12559+
12560+
static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
12561+
unsigned NumElts = N.getValueType().getVectorMinNumElements();
12562+
12563+
// Look through cast.
12564+
while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) {
12565+
N = N.getOperand(0);
12566+
// When reinterpreting from a type with fewer elements the "new" elements
12567+
// are not active, so bail if they're likely to be used.
12568+
if (N.getValueType().getVectorMinNumElements() < NumElts)
12569+
return false;
12570+
}
12571+
12572+
if (ISD::isConstantSplatVectorAllOnes(N.getNode()))
12573+
return true;
12574+
12575+
// "ptrue p.<ty>, all" can be considered all active when <ty> is the same size
12576+
// or smaller than the implicit element type represented by N.
12577+
// NOTE: A larger element count implies a smaller element type.
12578+
if (N.getOpcode() == AArch64ISD::PTRUE &&
12579+
N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
12580+
return N.getValueType().getVectorMinNumElements() >= NumElts;
12581+
12582+
// If we're compiling for a specific vector-length, we can check if the
12583+
// pattern's VL equals that of the scalable vector at runtime.
12584+
if (N.getOpcode() == AArch64ISD::PTRUE) {
12585+
const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
12586+
unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
12587+
unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
12588+
if (MaxSVESize && MinSVESize == MaxSVESize) {
12589+
unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
12590+
unsigned PatNumElts =
12591+
getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
12592+
return PatNumElts == (NumElts * VScale);
12593+
}
12594+
}
12595+
12596+
return false;
12597+
}
12598+
1254512599
// Attempt to form a vector S[LR]I from (or (and X, BvecC1), (lsl Y, C2)),
1254612600
// to (SLI X, Y, C2), where X and Y have matching vector types, BvecC1 is a
1254712601
// BUILD_VECTORs with constant element C1, C2 is a constant, and:
@@ -12567,59 +12621,78 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
1256712621
// Is one of the operands an AND or a BICi? The AND may have been optimised to
1256812622
// a BICi in order to use an immediate instead of a register.
1256912623
// Is the other operand an shl or lshr? This will have been turned into:
12570-
// AArch64ISD::VSHL vector, #shift or AArch64ISD::VLSHR vector, #shift.
12624+
// AArch64ISD::VSHL vector, #shift or AArch64ISD::VLSHR vector, #shift
12625+
// or (AArch64ISD::SHL_PRED || AArch64ISD::SRL_PRED) mask, vector, #shiftVec.
1257112626
if ((FirstOpc == ISD::AND || FirstOpc == AArch64ISD::BICi) &&
12572-
(SecondOpc == AArch64ISD::VSHL || SecondOpc == AArch64ISD::VLSHR)) {
12627+
(SecondOpc == AArch64ISD::VSHL || SecondOpc == AArch64ISD::VLSHR ||
12628+
SecondOpc == AArch64ISD::SHL_PRED ||
12629+
SecondOpc == AArch64ISD::SRL_PRED)) {
1257312630
And = FirstOp;
1257412631
Shift = SecondOp;
1257512632

1257612633
} else if ((SecondOpc == ISD::AND || SecondOpc == AArch64ISD::BICi) &&
12577-
(FirstOpc == AArch64ISD::VSHL || FirstOpc == AArch64ISD::VLSHR)) {
12634+
(FirstOpc == AArch64ISD::VSHL || FirstOpc == AArch64ISD::VLSHR ||
12635+
FirstOpc == AArch64ISD::SHL_PRED ||
12636+
FirstOpc == AArch64ISD::SRL_PRED)) {
1257812637
And = SecondOp;
1257912638
Shift = FirstOp;
1258012639
} else
1258112640
return SDValue();
1258212641

1258312642
bool IsAnd = And.getOpcode() == ISD::AND;
12584-
bool IsShiftRight = Shift.getOpcode() == AArch64ISD::VLSHR;
12585-
12586-
// Is the shift amount constant?
12587-
ConstantSDNode *C2node = dyn_cast<ConstantSDNode>(Shift.getOperand(1));
12588-
if (!C2node)
12643+
bool IsShiftRight = Shift.getOpcode() == AArch64ISD::VLSHR ||
12644+
Shift.getOpcode() == AArch64ISD::SRL_PRED;
12645+
bool ShiftHasPredOp = Shift.getOpcode() == AArch64ISD::SHL_PRED ||
12646+
Shift.getOpcode() == AArch64ISD::SRL_PRED;
12647+
12648+
// Is the shift amount constant and are all lanes active?
12649+
uint64_t C2;
12650+
if (ShiftHasPredOp) {
12651+
if (!isAllActivePredicate(DAG, Shift.getOperand(0)))
12652+
return SDValue();
12653+
APInt C;
12654+
if (!ISD::isConstantSplatVector(Shift.getOperand(2).getNode(), C))
12655+
return SDValue();
12656+
C2 = C.getZExtValue();
12657+
} else if (ConstantSDNode *C2node =
12658+
dyn_cast<ConstantSDNode>(Shift.getOperand(1)))
12659+
C2 = C2node->getZExtValue();
12660+
else
1258912661
return SDValue();
1259012662

12591-
uint64_t C1;
12663+
APInt C1AsAPInt;
12664+
unsigned ElemSizeInBits = VT.getScalarSizeInBits();
1259212665
if (IsAnd) {
1259312666
// Is the and mask vector all constant?
12594-
if (!isAllConstantBuildVector(And.getOperand(1), C1))
12667+
if (!ISD::isConstantSplatVector(And.getOperand(1).getNode(), C1AsAPInt))
1259512668
return SDValue();
1259612669
} else {
1259712670
// Reconstruct the corresponding AND immediate from the two BICi immediates.
1259812671
ConstantSDNode *C1nodeImm = dyn_cast<ConstantSDNode>(And.getOperand(1));
1259912672
ConstantSDNode *C1nodeShift = dyn_cast<ConstantSDNode>(And.getOperand(2));
1260012673
assert(C1nodeImm && C1nodeShift);
12601-
C1 = ~(C1nodeImm->getZExtValue() << C1nodeShift->getZExtValue());
12674+
C1AsAPInt = ~(C1nodeImm->getAPIntValue() << C1nodeShift->getAPIntValue());
12675+
C1AsAPInt = C1AsAPInt.zextOrTrunc(ElemSizeInBits);
1260212676
}
1260312677

1260412678
// Is C1 == ~(Ones(ElemSizeInBits) << C2) or
1260512679
// C1 == ~(Ones(ElemSizeInBits) >> C2), taking into account
1260612680
// how much one can shift elements of a particular size?
12607-
uint64_t C2 = C2node->getZExtValue();
12608-
unsigned ElemSizeInBits = VT.getScalarSizeInBits();
1260912681
if (C2 > ElemSizeInBits)
1261012682
return SDValue();
1261112683

12612-
APInt C1AsAPInt(ElemSizeInBits, C1);
1261312684
APInt RequiredC1 = IsShiftRight ? APInt::getHighBitsSet(ElemSizeInBits, C2)
1261412685
: APInt::getLowBitsSet(ElemSizeInBits, C2);
1261512686
if (C1AsAPInt != RequiredC1)
1261612687
return SDValue();
1261712688

1261812689
SDValue X = And.getOperand(0);
12619-
SDValue Y = Shift.getOperand(0);
12690+
SDValue Y = ShiftHasPredOp ? Shift.getOperand(1) : Shift.getOperand(0);
12691+
SDValue Imm = ShiftHasPredOp ? DAG.getTargetConstant(C2, DL, MVT::i32)
12692+
: Shift.getOperand(1);
1262012693

1262112694
unsigned Inst = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
12622-
SDValue ResultSLI = DAG.getNode(Inst, DL, VT, X, Y, Shift.getOperand(1));
12695+
SDValue ResultSLI = DAG.getNode(Inst, DL, VT, X, Y, Imm);
1262312696

1262412697
LLVM_DEBUG(dbgs() << "aarch64-lower: transformed: \n");
1262512698
LLVM_DEBUG(N->dump(&DAG));
@@ -12641,6 +12714,8 @@ SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
1264112714
return Res;
1264212715

1264312716
EVT VT = Op.getValueType();
12717+
if (VT.isScalableVector())
12718+
return Op;
1264412719

1264512720
SDValue LHS = Op.getOperand(0);
1264612721
BuildVectorSDNode *BVN =
@@ -17432,53 +17507,6 @@ static bool isConstantSplatVectorMaskForType(SDNode *N, EVT MemVT) {
1743217507
return false;
1743317508
}
1743417509

17435-
static bool isAllInactivePredicate(SDValue N) {
17436-
// Look through cast.
17437-
while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
17438-
N = N.getOperand(0);
17439-
17440-
return ISD::isConstantSplatVectorAllZeros(N.getNode());
17441-
}
17442-
17443-
static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
17444-
unsigned NumElts = N.getValueType().getVectorMinNumElements();
17445-
17446-
// Look through cast.
17447-
while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) {
17448-
N = N.getOperand(0);
17449-
// When reinterpreting from a type with fewer elements the "new" elements
17450-
// are not active, so bail if they're likely to be used.
17451-
if (N.getValueType().getVectorMinNumElements() < NumElts)
17452-
return false;
17453-
}
17454-
17455-
if (ISD::isConstantSplatVectorAllOnes(N.getNode()))
17456-
return true;
17457-
17458-
// "ptrue p.<ty>, all" can be considered all active when <ty> is the same size
17459-
// or smaller than the implicit element type represented by N.
17460-
// NOTE: A larger element count implies a smaller element type.
17461-
if (N.getOpcode() == AArch64ISD::PTRUE &&
17462-
N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
17463-
return N.getValueType().getVectorMinNumElements() >= NumElts;
17464-
17465-
// If we're compiling for a specific vector-length, we can check if the
17466-
// pattern's VL equals that of the scalable vector at runtime.
17467-
if (N.getOpcode() == AArch64ISD::PTRUE) {
17468-
const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
17469-
unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
17470-
unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
17471-
if (MaxSVESize && MinSVESize == MaxSVESize) {
17472-
unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
17473-
unsigned PatNumElts =
17474-
getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
17475-
return PatNumElts == (NumElts * VScale);
17476-
}
17477-
}
17478-
17479-
return false;
17480-
}
17481-
1748217510
static SDValue performReinterpretCastCombine(SDNode *N) {
1748317511
SDValue LeafOp = SDValue(N, 0);
1748417512
SDValue Op = N->getOperand(0);

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3577,8 +3577,8 @@ let Predicates = [HasSVE2orSME] in {
35773577
defm PMULLT_ZZZ : sve2_pmul_long<0b1, "pmullt", int_aarch64_sve_pmullt_pair>;
35783578

35793579
// SVE2 bitwise shift and insert
3580-
defm SRI_ZZI : sve2_int_bin_shift_imm_right<0b0, "sri", int_aarch64_sve_sri>;
3581-
defm SLI_ZZI : sve2_int_bin_shift_imm_left< 0b1, "sli", int_aarch64_sve_sli>;
3580+
defm SRI_ZZI : sve2_int_bin_shift_imm_right<0b0, "sri", AArch64vsri>;
3581+
defm SLI_ZZI : sve2_int_bin_shift_imm_left< 0b1, "sli", AArch64vsli>;
35823582

35833583
// SVE2 bitwise shift right and accumulate
35843584
defm SSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b00, "ssra", AArch64ssra>;

0 commit comments

Comments (0)