@@ -9596,6 +9596,21 @@ SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op,
                              Vec, Mask, VL, DL, DAG, Subtarget);
 }
 
+/// Returns true if \p LHS is known to be equal to \p RHS, taking into account
+/// if VLEN is exactly known by \p Subtarget and thus vscale when handling
+/// scalable quantities.
+static bool isKnownEQ(ElementCount LHS, ElementCount RHS,
+                      const RISCVSubtarget &Subtarget) {
+  if (auto VLen = Subtarget.getRealVLen()) {
+    const unsigned Vscale = *VLen / RISCV::RVVBitsPerBlock;
+    if (LHS.isScalable())
+      LHS = ElementCount::getFixed(LHS.getKnownMinValue() * Vscale);
+    if (RHS.isScalable())
+      RHS = ElementCount::getFixed(RHS.getKnownMinValue() * Vscale);
+  }
+  return LHS == RHS;
+}
+
 SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
                                                    SelectionDAG &DAG) const {
   SDValue Vec = Op.getOperand(0);
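The new isKnownEQ helper above compares two ElementCounts, collapsing scalable counts to fixed ones when the subtarget reports an exact VLEN (vscale = VLEN / RISCV::RVVBitsPerBlock, with RVVBitsPerBlock being 64). Below is a minimal standalone sketch of that comparison using plain stand-in types rather than LLVM's ElementCount/RISCVSubtarget; the struct, the helper name reuse, and the VLEN=128 figure are illustrative assumptions, not part of the patch.

```cpp
#include <cstdio>
#include <optional>

// Stand-in for llvm::ElementCount: a known-minimum count, optionally scalable.
struct EC { unsigned MinVal; bool Scalable; };
constexpr unsigned RVVBitsPerBlock = 64; // same value as RISCV::RVVBitsPerBlock

// Sketch of the same idea as the patch's isKnownEQ: with an exact VLEN, a
// scalable count of N really means N * vscale fixed elements, so both sides
// can be normalized to fixed counts before comparing.
bool isKnownEQ(EC LHS, EC RHS, std::optional<unsigned> RealVLen) {
  if (RealVLen) {
    const unsigned Vscale = *RealVLen / RVVBitsPerBlock;
    if (LHS.Scalable) LHS = {LHS.MinVal * Vscale, false};
    if (RHS.Scalable) RHS = {RHS.MinVal * Vscale, false};
  }
  return LHS.Scalable == RHS.Scalable && LHS.MinVal == RHS.MinVal;
}

int main() {
  // Assuming VLEN is known to be exactly 128: vscale = 2, so a scalable count
  // of 2 elements equals a fixed count of 4.
  std::printf("%d\n", isKnownEQ({2, true}, {4, false}, 128u));         // 1
  // Without an exact VLEN the two counts cannot be proven equal.
  std::printf("%d\n", isKnownEQ({2, true}, {4, false}, std::nullopt)); // 0
}
```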
@@ -9645,12 +9660,13 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
     }
   }
 
-  // If the subvector vector is a fixed-length type, we cannot use subregister
-  // manipulation to simplify the codegen; we don't know which register of a
-  // LMUL group contains the specific subvector as we only know the minimum
-  // register size. Therefore we must slide the vector group up the full
-  // amount.
-  if (SubVecVT.isFixedLengthVector()) {
+  // If the subvector vector is a fixed-length type and we don't know VLEN
+  // exactly, we cannot use subregister manipulation to simplify the codegen; we
+  // don't know which register of a LMUL group contains the specific subvector
+  // as we only know the minimum register size. Therefore we must slide the
+  // vector group up the full amount.
+  const auto VLen = Subtarget.getRealVLen();
+  if (SubVecVT.isFixedLengthVector() && !VLen) {
     if (OrigIdx == 0 && Vec.isUndef() && !VecVT.isFixedLengthVector())
       return Op;
     MVT ContainerVT = VecVT;
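The gate added here captures why the rest of the patch exists: without an exact VLEN, a fixed-length subvector's position relative to vector register boundaries is unknowable, so the generic slide-up path must be used. A small standalone illustration of that dependence (plain integers; the v8i32/v4i32/index-4 example and the two VLEN values are assumptions chosen for the example):

```cpp
#include <cstdio>
#include <initializer_list>

int main() {
  // Example: insert a fixed v4i32 subvector at element index 4 of a fixed
  // v8i32 vector. Whether that insert lands exactly on a vector register
  // depends on VLEN, for which the compiler normally only knows a lower bound.
  const unsigned ElemBits = 32, SubVecElems = 4, OrigIdx = 4;

  for (unsigned VLen : {128u, 256u}) {
    const unsigned ElemsPerReg = VLen / ElemBits;
    // Register-aligned iff the index is a register boundary and the subvector
    // exactly fills whole registers; otherwise a vslideup is needed.
    const bool Aligned =
        OrigIdx % ElemsPerReg == 0 && (SubVecElems * ElemBits) % VLen == 0;
    std::printf("VLEN=%u: %s\n", VLen,
                Aligned ? "subregister insert" : "needs a slide");
  }
}
```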
@@ -9698,41 +9714,92 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
     return DAG.getBitcast(Op.getValueType(), SubVec);
   }
 
-  unsigned SubRegIdx, RemIdx;
-  std::tie(SubRegIdx, RemIdx) =
-      RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
-          VecVT, SubVecVT, OrigIdx, TRI);
+  MVT ContainerVecVT = VecVT;
+  if (VecVT.isFixedLengthVector()) {
+    ContainerVecVT = getContainerForFixedLengthVector(VecVT);
+    Vec = convertToScalableVector(ContainerVecVT, Vec, DAG, Subtarget);
+  }
 
-  RISCVII::VLMUL SubVecLMUL = RISCVTargetLowering::getLMUL(SubVecVT);
+  MVT ContainerSubVecVT = SubVecVT;
+  if (SubVecVT.isFixedLengthVector()) {
+    ContainerSubVecVT = getContainerForFixedLengthVector(SubVecVT);
+    SubVec = convertToScalableVector(ContainerSubVecVT, SubVec, DAG, Subtarget);
+  }
+
+  unsigned SubRegIdx;
+  ElementCount RemIdx;
+  // insert_subvector scales the index by vscale if the subvector is scalable,
+  // and decomposeSubvectorInsertExtractToSubRegs takes this into account. So if
+  // we have a fixed length subvector, we need to adjust the index by 1/vscale.
+  if (SubVecVT.isFixedLengthVector()) {
+    assert(VLen);
+    unsigned Vscale = *VLen / RISCV::RVVBitsPerBlock;
+    auto Decompose =
+        RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
+            ContainerVecVT, ContainerSubVecVT, OrigIdx / Vscale, TRI);
+    SubRegIdx = Decompose.first;
+    RemIdx = ElementCount::getFixed((Decompose.second * Vscale) +
+                                    (OrigIdx % Vscale));
+  } else {
+    auto Decompose =
+        RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
+            ContainerVecVT, ContainerSubVecVT, OrigIdx, TRI);
+    SubRegIdx = Decompose.first;
+    RemIdx = ElementCount::getScalable(Decompose.second);
+  }
+
+  RISCVII::VLMUL SubVecLMUL = RISCVTargetLowering::getLMUL(ContainerSubVecVT);
   bool IsSubVecPartReg = SubVecLMUL == RISCVII::VLMUL::LMUL_F2 ||
                          SubVecLMUL == RISCVII::VLMUL::LMUL_F4 ||
                          SubVecLMUL == RISCVII::VLMUL::LMUL_F8;
+  bool AlignedToVecReg = !IsSubVecPartReg;
+  if (SubVecVT.isFixedLengthVector())
+    AlignedToVecReg &= SubVecVT.getSizeInBits() ==
+                       ContainerSubVecVT.getSizeInBits().getKnownMinValue() *
+                           (*VLen / RISCV::RVVBitsPerBlock);
 
   // 1. If the Idx has been completely eliminated and this subvector's size is
   // a vector register or a multiple thereof, or the surrounding elements are
   // undef, then this is a subvector insert which naturally aligns to a vector
   // register. These can easily be handled using subregister manipulation.
-  // 2. If the subvector is smaller than a vector register, then the insertion
-  // must preserve the undisturbed elements of the register. We do this by
-  // lowering to an EXTRACT_SUBVECTOR grabbing the nearest LMUL=1 vector type
-  // (which resolves to a subregister copy), performing a VSLIDEUP to place the
-  // subvector within the vector register, and an INSERT_SUBVECTOR of that
+  // 2. If the subvector isn't exactly aligned to a vector register group, then
+  // the insertion must preserve the undisturbed elements of the register. We do
+  // this by lowering to an EXTRACT_SUBVECTOR grabbing the nearest LMUL=1 vector
+  // type (which resolves to a subregister copy), performing a VSLIDEUP to place
+  // the subvector within the vector register, and an INSERT_SUBVECTOR of that
   // LMUL=1 type back into the larger vector (resolving to another subregister
   // operation). See below for how our VSLIDEUP works. We go via a LMUL=1 type
   // to avoid allocating a large register group to hold our subvector.
-  if (RemIdx == 0 && (!IsSubVecPartReg || Vec.isUndef()))
+  if (RemIdx.isZero() && (AlignedToVecReg || Vec.isUndef())) {
+    if (SubVecVT.isFixedLengthVector()) {
+      // We may get NoSubRegister if inserting at index 0 and the subvec
+      // container is the same as the vector, e.g. vec=v4i32,subvec=v4i32,idx=0
+      if (SubRegIdx == RISCV::NoSubRegister) {
+        assert(OrigIdx == 0);
+        return Op;
+      }
+
+      SDValue Insert =
+          DAG.getTargetInsertSubreg(SubRegIdx, DL, ContainerVecVT, Vec, SubVec);
+      if (VecVT.isFixedLengthVector())
+        Insert = convertFromScalableVector(VecVT, Insert, DAG, Subtarget);
+      return Insert;
+    }
     return Op;
+  }
 
   // VSLIDEUP works by leaving elements 0<i<OFFSET undisturbed, elements
   // OFFSET<=i<VL set to the "subvector" and vl<=i<VLMAX set to the tail policy
   // (in our case undisturbed). This means we can set up a subvector insertion
   // where OFFSET is the insertion offset, and the VL is the OFFSET plus the
   // size of the subvector.
-  MVT InterSubVT = VecVT;
+  MVT InterSubVT = ContainerVecVT;
   SDValue AlignedExtract = Vec;
-  unsigned AlignedIdx = OrigIdx - RemIdx;
-  if (VecVT.bitsGT(getLMUL1VT(VecVT))) {
-    InterSubVT = getLMUL1VT(VecVT);
+  unsigned AlignedIdx = OrigIdx - RemIdx.getKnownMinValue();
+  if (SubVecVT.isFixedLengthVector())
+    AlignedIdx /= *VLen / RISCV::RVVBitsPerBlock;
+  if (ContainerVecVT.bitsGT(getLMUL1VT(ContainerVecVT))) {
+    InterSubVT = getLMUL1VT(ContainerVecVT);
     // Extract a subvector equal to the nearest full vector register type. This
     // should resolve to a EXTRACT_SUBREG instruction.
     AlignedExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InterSubVT, Vec,
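The fixed-length branch above feeds OrigIdx / Vscale into the (scalable-unit) decomposition and then rescales the remainder, re-adding OrigIdx % Vscale. The following plain-integer sketch walks that bookkeeping with concrete numbers; the VLEN=128, i32-element, and index-6 values, and the simplified per-register split standing in for decomposeSubvectorInsertExtractToSubRegs, are assumptions for illustration only:

```cpp
#include <cstdio>

int main() {
  // Assumptions for illustration: VLEN is exactly 128 and elements are i32,
  // so one vector register holds 4 fixed i32 (known-min 2 scalable elements).
  const unsigned VLen = 128, RVVBitsPerBlock = 64;
  const unsigned Vscale = VLen / RVVBitsPerBlock;   // 2
  const unsigned ElemsPerRegScalable = 2;           // nxv2i32 per register

  // Insert a fixed-length subvector at element index 6 of a register group.
  const unsigned OrigIdx = 6;

  // The decomposition works in scalable (known-min) units, so feed it
  // OrigIdx / Vscale and split that into a register number plus a remainder.
  const unsigned ScalableIdx = OrigIdx / Vscale;                   // 3
  const unsigned SubReg = ScalableIdx / ElemsPerRegScalable;       // register 1
  const unsigned RemScalable = ScalableIdx % ElemsPerRegScalable;  // 1

  // Scale the remainder back to fixed elements and re-add what the integer
  // division dropped, mirroring
  // RemIdx = Decompose.second * Vscale + OrigIdx % Vscale.
  const unsigned RemIdx = RemScalable * Vscale + OrigIdx % Vscale; // 2

  std::printf("subregister %u, remaining fixed index %u\n", SubReg, RemIdx);
}
```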
@@ -9743,25 +9810,23 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
                        DAG.getUNDEF(InterSubVT), SubVec,
                        DAG.getVectorIdxConstant(0, DL));
 
-  auto [Mask, VL] = getDefaultScalableVLOps(VecVT, DL, DAG, Subtarget);
+  auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVecVT, DL, DAG, Subtarget);
 
-  ElementCount EndIndex =
-      ElementCount::getScalable(RemIdx) + SubVecVT.getVectorElementCount();
-  VL = computeVLMax(SubVecVT, DL, DAG);
+  ElementCount EndIndex = RemIdx + SubVecVT.getVectorElementCount();
+  VL = DAG.getElementCount(DL, XLenVT, SubVecVT.getVectorElementCount());
 
   // Use tail agnostic policy if we're inserting over InterSubVT's tail.
   unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
-  if (EndIndex == InterSubVT.getVectorElementCount())
+  if (isKnownEQ(EndIndex, InterSubVT.getVectorElementCount(), Subtarget))
     Policy = RISCVII::TAIL_AGNOSTIC;
 
   // If we're inserting into the lowest elements, use a tail undisturbed
   // vmv.v.v.
-  if (RemIdx == 0) {
+  if (RemIdx.isZero()) {
     SubVec = DAG.getNode(RISCVISD::VMV_V_V_VL, DL, InterSubVT, AlignedExtract,
                          SubVec, VL);
   } else {
-    SDValue SlideupAmt =
-        DAG.getVScale(DL, XLenVT, APInt(XLenVT.getSizeInBits(), RemIdx));
+    SDValue SlideupAmt = DAG.getElementCount(DL, XLenVT, RemIdx);
 
     // Construct the vector length corresponding to RemIdx + length(SubVecVT).
     VL = DAG.getNode(ISD::ADD, DL, XLenVT, SlideupAmt, VL);
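For the slide-up path, the end index is the remaining offset plus the subvector's length, and the tail policy may only be agnostic when that end index provably reaches the end of the LMUL=1 type, which isKnownEQ can now prove even when one side is a fixed count. A tiny worked example under an assumed VLEN of 128 with i32 elements (all values are illustrative):

```cpp
#include <cstdio>

int main() {
  // Assumed layout: VLEN = 128 and i32 elements, so the LMUL=1 type
  // (InterSubVT) holds 4 elements. Insert a 2-element fixed subvector at a
  // remaining offset of 2 within that register.
  const unsigned InterSubElems = 4, SubVecElems = 2, RemIdx = 2;

  // The vslideup's VL covers offset + subvector length.
  const unsigned EndIndex = RemIdx + SubVecElems; // 4

  // Tail-agnostic is only safe when the write reaches InterSubVT's end, i.e.
  // there is no tail left that must stay undisturbed.
  const bool TailAgnostic = EndIndex == InterSubElems; // true here
  std::printf("VL=%u, tail agnostic=%d\n", EndIndex, TailAgnostic);
}
```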
@@ -9772,10 +9837,13 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
 
   // If required, insert this subvector back into the correct vector register.
   // This should resolve to an INSERT_SUBREG instruction.
-  if (VecVT.bitsGT(InterSubVT))
-    SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecVT, Vec, SubVec,
+  if (ContainerVecVT.bitsGT(InterSubVT))
+    SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVecVT, Vec, SubVec,
                          DAG.getVectorIdxConstant(AlignedIdx, DL));
 
+  if (VecVT.isFixedLengthVector())
+    SubVec = convertFromScalableVector(VecVT, SubVec, DAG, Subtarget);
+
   // We might have bitcast from a mask type: cast back to the original type if
   // required.
   return DAG.getBitcast(Op.getSimpleValueType(), SubVec);