@@ -1358,6 +1358,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
       if (!Subtarget->isLittleEndian())
         setOperationAction(ISD::BITCAST, VT, Expand);
+
+      if (Subtarget->hasSVE2orSME())
+        // For SLI/SRI.
+        setOperationAction(ISD::OR, VT, Custom);
     }
 
     // Illegal unpacked integer vector types.
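For context on what this hook enables: marking ISD::OR as Custom for these SVE types routes vector `or` nodes through LowerVectorOR, where tryLowerToSLI (below) can fold a mask-and-shift OR into a single SLI/SRI. A minimal sketch of the kind of source that now qualifies, written with ACLE intrinsics and assuming an SVE2 target (`-march=armv8-a+sve2`); the function name is illustrative:

```cpp
#include <arm_sve.h>

// (x & 0x00FFFFFF) | (y << 24): with an all-active predicate and a splat
// mask/shift, this OR is a candidate for a single "sli z.s, z.s, #24".
svuint32_t insert_high_byte(svuint32_t x, svuint32_t y) {
  svbool_t pg = svptrue_b32();
  svuint32_t lo = svand_n_u32_x(pg, x, 0x00FFFFFFu);
  svuint32_t hi = svlsl_n_u32_x(pg, y, 24);
  return svorr_u32_x(pg, lo, hi);
}
```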
@@ -5409,15 +5413,18 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   }
 
   case Intrinsic::aarch64_neon_vsri:
-  case Intrinsic::aarch64_neon_vsli: {
+  case Intrinsic::aarch64_neon_vsli:
+  case Intrinsic::aarch64_sve_sri:
+  case Intrinsic::aarch64_sve_sli: {
     EVT Ty = Op.getValueType();
 
     if (!Ty.isVector())
       report_fatal_error("Unexpected type for aarch64_neon_vsli");
 
     assert(Op.getConstantOperandVal(3) <= Ty.getScalarSizeInBits());
 
-    bool IsShiftRight = IntNo == Intrinsic::aarch64_neon_vsri;
+    bool IsShiftRight = IntNo == Intrinsic::aarch64_neon_vsri ||
+                        IntNo == Intrinsic::aarch64_sve_sri;
     unsigned Opcode = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
     return DAG.getNode(Opcode, dl, Ty, Op.getOperand(1), Op.getOperand(2),
                        Op.getOperand(3));
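The SVE shift-insert intrinsics now share the NEON VSLI/VSRI path. A short usage sketch (SVE2 ACLE; `svsli_n_u32`/`svsri_n_u32` are the arm_sve.h spellings that produce `llvm.aarch64.sve.sli`/`sri`):

```cpp
#include <arm_sve.h>

// These lower to llvm.aarch64.sve.sli/sri, which the case above now maps to
// AArch64ISD::VSLI/VSRI, the same nodes used for the NEON variants.
svuint32_t shift_left_insert(svuint32_t x, svuint32_t y) {
  return svsli_n_u32(x, y, 8); // sli z0.s, z1.s, #8
}
svuint32_t shift_right_insert(svuint32_t x, svuint32_t y) {
  return svsri_n_u32(x, y, 8); // sri z0.s, z1.s, #8
}
```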
@@ -12542,6 +12549,53 @@ static bool isAllConstantBuildVector(const SDValue &PotentialBVec,
   return true;
 }
 
+static bool isAllInactivePredicate(SDValue N) {
+  // Look through cast.
+  while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    N = N.getOperand(0);
+
+  return ISD::isConstantSplatVectorAllZeros(N.getNode());
+}
+
+static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
+  unsigned NumElts = N.getValueType().getVectorMinNumElements();
+
+  // Look through cast.
+  while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) {
+    N = N.getOperand(0);
+    // When reinterpreting from a type with fewer elements the "new" elements
+    // are not active, so bail if they're likely to be used.
+    if (N.getValueType().getVectorMinNumElements() < NumElts)
+      return false;
+  }
+
+  if (ISD::isConstantSplatVectorAllOnes(N.getNode()))
+    return true;
+
+  // "ptrue p.<ty>, all" can be considered all active when <ty> is the same size
+  // or smaller than the implicit element type represented by N.
+  // NOTE: A larger element count implies a smaller element type.
+  if (N.getOpcode() == AArch64ISD::PTRUE &&
+      N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
+    return N.getValueType().getVectorMinNumElements() >= NumElts;
+
+  // If we're compiling for a specific vector-length, we can check if the
+  // pattern's VL equals that of the scalable vector at runtime.
+  if (N.getOpcode() == AArch64ISD::PTRUE) {
+    const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+    unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
+    unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
+    if (MaxSVESize && MinSVESize == MaxSVESize) {
+      unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
+      unsigned PatNumElts =
+          getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
+      return PatNumElts == (NumElts * VScale);
+    }
+  }
+
+  return false;
+}
+
 // Attempt to form a vector S[LR]I from (or (and X, BvecC1), (lsl Y, C2)),
 // to (SLI X, Y, C2), where X and Y have matching vector types, BvecC1 is a
 // BUILD_VECTORs with constant element C1, C2 is a constant, and:
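As an aside on the fixed-vector-length branch of isAllActivePredicate above, here is the arithmetic on concrete numbers; a sketch assuming the code is built with a known SVE width of 256 bits (so MinSVESize == MaxSVESize == 256), with illustrative values:

```cpp
// "ptrue p0.s, vl8" over an nxv4i32-shaped predicate at 256-bit SVE:
bool vl8AllActiveAt256() {
  unsigned MaxSVESize = 256;          // known fixed vector length in bits
  unsigned VScale = MaxSVESize / 128; // AArch64::SVEBitsPerBlock is 128 -> 2
  unsigned NumElts = 4;               // nxv4i32: 4 lanes per 128-bit granule
  unsigned PatNumElts = 8;            // the vl8 pattern enables 8 lanes
  return PatNumElts == NumElts * VScale; // 8 == 4 * 2 -> all lanes active
}
```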
@@ -12567,59 +12621,78 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
   // Is one of the operands an AND or a BICi? The AND may have been optimised to
   // a BICi in order to use an immediate instead of a register.
   // Is the other operand an shl or lshr? This will have been turned into:
-  // AArch64ISD::VSHL vector, #shift or AArch64ISD::VLSHR vector, #shift.
+  // AArch64ISD::VSHL vector, #shift or AArch64ISD::VLSHR vector, #shift
+  // or (AArch64ISD::SHL_PRED || AArch64ISD::SRL_PRED) mask, vector, #shiftVec.
   if ((FirstOpc == ISD::AND || FirstOpc == AArch64ISD::BICi) &&
-      (SecondOpc == AArch64ISD::VSHL || SecondOpc == AArch64ISD::VLSHR)) {
+      (SecondOpc == AArch64ISD::VSHL || SecondOpc == AArch64ISD::VLSHR ||
+       SecondOpc == AArch64ISD::SHL_PRED ||
+       SecondOpc == AArch64ISD::SRL_PRED)) {
     And = FirstOp;
     Shift = SecondOp;
 
   } else if ((SecondOpc == ISD::AND || SecondOpc == AArch64ISD::BICi) &&
-             (FirstOpc == AArch64ISD::VSHL || FirstOpc == AArch64ISD::VLSHR)) {
+             (FirstOpc == AArch64ISD::VSHL || FirstOpc == AArch64ISD::VLSHR ||
+              FirstOpc == AArch64ISD::SHL_PRED ||
+              FirstOpc == AArch64ISD::SRL_PRED)) {
     And = SecondOp;
     Shift = FirstOp;
   } else
     return SDValue();
 
   bool IsAnd = And.getOpcode() == ISD::AND;
-  bool IsShiftRight = Shift.getOpcode() == AArch64ISD::VLSHR;
-
-  // Is the shift amount constant?
-  ConstantSDNode *C2node = dyn_cast<ConstantSDNode>(Shift.getOperand(1));
-  if (!C2node)
+  bool IsShiftRight = Shift.getOpcode() == AArch64ISD::VLSHR ||
+                      Shift.getOpcode() == AArch64ISD::SRL_PRED;
+  bool ShiftHasPredOp = Shift.getOpcode() == AArch64ISD::SHL_PRED ||
+                        Shift.getOpcode() == AArch64ISD::SRL_PRED;
+
+  // Is the shift amount constant and are all lanes active?
+  uint64_t C2;
+  if (ShiftHasPredOp) {
+    if (!isAllActivePredicate(DAG, Shift.getOperand(0)))
+      return SDValue();
+    APInt C;
+    if (!ISD::isConstantSplatVector(Shift.getOperand(2).getNode(), C))
+      return SDValue();
+    C2 = C.getZExtValue();
+  } else if (ConstantSDNode *C2node =
+                 dyn_cast<ConstantSDNode>(Shift.getOperand(1)))
+    C2 = C2node->getZExtValue();
+  else
     return SDValue();
 
-  uint64_t C1;
+  APInt C1AsAPInt;
+  unsigned ElemSizeInBits = VT.getScalarSizeInBits();
   if (IsAnd) {
     // Is the and mask vector all constant?
-    if (!isAllConstantBuildVector(And.getOperand(1), C1))
+    if (!ISD::isConstantSplatVector(And.getOperand(1).getNode(), C1AsAPInt))
       return SDValue();
   } else {
     // Reconstruct the corresponding AND immediate from the two BICi immediates.
     ConstantSDNode *C1nodeImm = dyn_cast<ConstantSDNode>(And.getOperand(1));
     ConstantSDNode *C1nodeShift = dyn_cast<ConstantSDNode>(And.getOperand(2));
     assert(C1nodeImm && C1nodeShift);
-    C1 = ~(C1nodeImm->getZExtValue() << C1nodeShift->getZExtValue());
+    C1AsAPInt = ~(C1nodeImm->getAPIntValue() << C1nodeShift->getAPIntValue());
+    C1AsAPInt = C1AsAPInt.zextOrTrunc(ElemSizeInBits);
   }
 
   // Is C1 == ~(Ones(ElemSizeInBits) << C2) or
   // C1 == ~(Ones(ElemSizeInBits) >> C2), taking into account
   // how much one can shift elements of a particular size?
-  uint64_t C2 = C2node->getZExtValue();
-  unsigned ElemSizeInBits = VT.getScalarSizeInBits();
   if (C2 > ElemSizeInBits)
     return SDValue();
 
-  APInt C1AsAPInt(ElemSizeInBits, C1);
   APInt RequiredC1 = IsShiftRight ? APInt::getHighBitsSet(ElemSizeInBits, C2)
                                   : APInt::getLowBitsSet(ElemSizeInBits, C2);
   if (C1AsAPInt != RequiredC1)
     return SDValue();
 
   SDValue X = And.getOperand(0);
-  SDValue Y = Shift.getOperand(0);
+  SDValue Y = ShiftHasPredOp ? Shift.getOperand(1) : Shift.getOperand(0);
+  SDValue Imm = ShiftHasPredOp ? DAG.getTargetConstant(C2, DL, MVT::i32)
+                               : Shift.getOperand(1);
 
   unsigned Inst = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
-  SDValue ResultSLI = DAG.getNode(Inst, DL, VT, X, Y, Shift.getOperand(1));
+  SDValue ResultSLI = DAG.getNode(Inst, DL, VT, X, Y, Imm);
 
   LLVM_DEBUG(dbgs() << "aarch64-lower: transformed: \n");
   LLVM_DEBUG(N->dump(&DAG));
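To make the RequiredC1 check above concrete, here is a small self-contained sketch using LLVM's APInt; the constants are illustrative:

```cpp
#include "llvm/ADT/APInt.h"
using namespace llvm;

// For (or (and X, C1), (shl Y, 24)) on 32-bit lanes: SLI #24 keeps the low
// 24 bits of X and inserts Y << 24 above them, so the AND mask must equal
// ~(Ones(32) << 24), i.e. the low 24 bits set (0x00FFFFFF).
bool sliMaskMatches() {
  unsigned ElemSizeInBits = 32, C2 = 24;
  APInt C1AsAPInt(ElemSizeInBits, 0x00FFFFFF);
  APInt RequiredC1 = APInt::getLowBitsSet(ElemSizeInBits, C2);
  return C1AsAPInt == RequiredC1; // true -> the OR becomes VSLI X, Y, #24
}
```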
@@ -12641,6 +12714,8 @@ SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
     return Res;
 
   EVT VT = Op.getValueType();
+  if (VT.isScalableVector())
+    return Op;
 
   SDValue LHS = Op.getOperand(0);
   BuildVectorSDNode *BVN =
@@ -17432,53 +17507,6 @@ static bool isConstantSplatVectorMaskForType(SDNode *N, EVT MemVT) {
   return false;
 }
 
-static bool isAllInactivePredicate(SDValue N) {
-  // Look through cast.
-  while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
-    N = N.getOperand(0);
-
-  return ISD::isConstantSplatVectorAllZeros(N.getNode());
-}
-
-static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
-  unsigned NumElts = N.getValueType().getVectorMinNumElements();
-
-  // Look through cast.
-  while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) {
-    N = N.getOperand(0);
-    // When reinterpreting from a type with fewer elements the "new" elements
-    // are not active, so bail if they're likely to be used.
-    if (N.getValueType().getVectorMinNumElements() < NumElts)
-      return false;
-  }
-
-  if (ISD::isConstantSplatVectorAllOnes(N.getNode()))
-    return true;
-
-  // "ptrue p.<ty>, all" can be considered all active when <ty> is the same size
-  // or smaller than the implicit element type represented by N.
-  // NOTE: A larger element count implies a smaller element type.
-  if (N.getOpcode() == AArch64ISD::PTRUE &&
-      N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
-    return N.getValueType().getVectorMinNumElements() >= NumElts;
-
-  // If we're compiling for a specific vector-length, we can check if the
-  // pattern's VL equals that of the scalable vector at runtime.
-  if (N.getOpcode() == AArch64ISD::PTRUE) {
-    const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
-    unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
-    unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
-    if (MaxSVESize && MinSVESize == MaxSVESize) {
-      unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
-      unsigned PatNumElts =
-          getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
-      return PatNumElts == (NumElts * VScale);
-    }
-  }
-
-  return false;
-}
-
 static SDValue performReinterpretCastCombine(SDNode *N) {
   SDValue LeafOp = SDValue(N, 0);
   SDValue Op = N->getOperand(0);