@@ -9720,6 +9720,21 @@ SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op,
9720
9720
Vec, Mask, VL, DL, DAG, Subtarget);
9721
9721
}
9722
9722
9723
+ /// Returns true if \p LHS is known to be equal to \p RHS, taking into account
9724
+ /// if VLEN is exactly known by \p Subtarget and thus vscale when handling
9725
+ /// scalable quantities.
9726
+ static bool isKnownEQ(ElementCount LHS, ElementCount RHS,
9727
+ const RISCVSubtarget &Subtarget) {
9728
+ if (auto VLen = Subtarget.getRealVLen()) {
9729
+ const unsigned Vscale = *VLen / RISCV::RVVBitsPerBlock;
9730
+ if (LHS.isScalable())
9731
+ LHS = ElementCount::getFixed(LHS.getKnownMinValue() * Vscale);
9732
+ if (RHS.isScalable())
9733
+ RHS = ElementCount::getFixed(RHS.getKnownMinValue() * Vscale);
9734
+ }
9735
+ return LHS == RHS;
9736
+ }
9737
+
9723
9738
SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
9724
9739
SelectionDAG &DAG) const {
9725
9740
SDValue Vec = Op.getOperand(0);
@@ -9769,12 +9784,13 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
9769
9784
}
9770
9785
}
9771
9786
9772
- // If the subvector vector is a fixed-length type, we cannot use subregister
9773
- // manipulation to simplify the codegen; we don't know which register of a
9774
- // LMUL group contains the specific subvector as we only know the minimum
9775
- // register size. Therefore we must slide the vector group up the full
9776
- // amount.
9777
- if (SubVecVT.isFixedLengthVector()) {
9787
+ // If the subvector vector is a fixed-length type and we don't know VLEN
9788
+ // exactly, we cannot use subregister manipulation to simplify the codegen; we
9789
+ // don't know which register of a LMUL group contains the specific subvector
9790
+ // as we only know the minimum register size. Therefore we must slide the
9791
+ // vector group up the full amount.
9792
+ const auto VLen = Subtarget.getRealVLen();
9793
+ if (SubVecVT.isFixedLengthVector() && !VLen) {
9778
9794
if (OrigIdx == 0 && Vec.isUndef() && !VecVT.isFixedLengthVector())
9779
9795
return Op;
9780
9796
MVT ContainerVT = VecVT;
@@ -9822,41 +9838,92 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
9822
9838
return DAG.getBitcast(Op.getValueType(), SubVec);
9823
9839
}
9824
9840
9825
- unsigned SubRegIdx, RemIdx;
9826
- std::tie(SubRegIdx, RemIdx) =
9827
- RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
9828
- VecVT, SubVecVT, OrigIdx, TRI);
9841
+ MVT ContainerVecVT = VecVT;
9842
+ if (VecVT.isFixedLengthVector()) {
9843
+ ContainerVecVT = getContainerForFixedLengthVector(VecVT);
9844
+ Vec = convertToScalableVector(ContainerVecVT, Vec, DAG, Subtarget);
9845
+ }
9829
9846
9830
- RISCVII::VLMUL SubVecLMUL = RISCVTargetLowering::getLMUL(SubVecVT);
9847
+ MVT ContainerSubVecVT = SubVecVT;
9848
+ if (SubVecVT.isFixedLengthVector()) {
9849
+ ContainerSubVecVT = getContainerForFixedLengthVector(SubVecVT);
9850
+ SubVec = convertToScalableVector(ContainerSubVecVT, SubVec, DAG, Subtarget);
9851
+ }
9852
+
9853
+ unsigned SubRegIdx;
9854
+ ElementCount RemIdx;
9855
+ // insert_subvector scales the index by vscale if the subvector is scalable,
9856
+ // and decomposeSubvectorInsertExtractToSubRegs takes this into account. So if
9857
+ // we have a fixed length subvector, we need to adjust the index by 1/vscale.
9858
+ if (SubVecVT.isFixedLengthVector()) {
9859
+ assert(VLen);
9860
+ unsigned Vscale = *VLen / RISCV::RVVBitsPerBlock;
9861
+ auto Decompose =
9862
+ RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
9863
+ ContainerVecVT, ContainerSubVecVT, OrigIdx / Vscale, TRI);
9864
+ SubRegIdx = Decompose.first;
9865
+ RemIdx = ElementCount::getFixed((Decompose.second * Vscale) +
9866
+ (OrigIdx % Vscale));
9867
+ } else {
9868
+ auto Decompose =
9869
+ RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
9870
+ ContainerVecVT, ContainerSubVecVT, OrigIdx, TRI);
9871
+ SubRegIdx = Decompose.first;
9872
+ RemIdx = ElementCount::getScalable(Decompose.second);
9873
+ }
9874
+
9875
+ RISCVII::VLMUL SubVecLMUL = RISCVTargetLowering::getLMUL(ContainerSubVecVT);
9831
9876
bool IsSubVecPartReg = SubVecLMUL == RISCVII::VLMUL::LMUL_F2 ||
9832
9877
SubVecLMUL == RISCVII::VLMUL::LMUL_F4 ||
9833
9878
SubVecLMUL == RISCVII::VLMUL::LMUL_F8;
9879
+ bool AlignedToVecReg = !IsSubVecPartReg;
9880
+ if (SubVecVT.isFixedLengthVector())
9881
+ AlignedToVecReg &= SubVecVT.getSizeInBits() ==
9882
+ ContainerSubVecVT.getSizeInBits().getKnownMinValue() *
9883
+ (*VLen / RISCV::RVVBitsPerBlock);
9834
9884
9835
9885
// 1. If the Idx has been completely eliminated and this subvector's size is
9836
9886
// a vector register or a multiple thereof, or the surrounding elements are
9837
9887
// undef, then this is a subvector insert which naturally aligns to a vector
9838
9888
// register. These can easily be handled using subregister manipulation.
9839
- // 2. If the subvector is smaller than a vector register, then the insertion
9840
- // must preserve the undisturbed elements of the register. We do this by
9841
- // lowering to an EXTRACT_SUBVECTOR grabbing the nearest LMUL=1 vector type
9842
- // (which resolves to a subregister copy), performing a VSLIDEUP to place the
9843
- // subvector within the vector register, and an INSERT_SUBVECTOR of that
9889
+ // 2. If the subvector isn't exactly aligned to a vector register group , then
9890
+ // the insertion must preserve the undisturbed elements of the register. We do
9891
+ // this by lowering to an EXTRACT_SUBVECTOR grabbing the nearest LMUL=1 vector
9892
+ // type (which resolves to a subregister copy), performing a VSLIDEUP to place
9893
+ // the subvector within the vector register, and an INSERT_SUBVECTOR of that
9844
9894
// LMUL=1 type back into the larger vector (resolving to another subregister
9845
9895
// operation). See below for how our VSLIDEUP works. We go via a LMUL=1 type
9846
9896
// to avoid allocating a large register group to hold our subvector.
9847
- if (RemIdx == 0 && (!IsSubVecPartReg || Vec.isUndef()))
9897
+ if (RemIdx.isZero() && (AlignedToVecReg || Vec.isUndef())) {
9898
+ if (SubVecVT.isFixedLengthVector()) {
9899
+ // We may get NoSubRegister if inserting at index 0 and the subvec
9900
+ // container is the same as the vector, e.g. vec=v4i32,subvec=v4i32,idx=0
9901
+ if (SubRegIdx == RISCV::NoSubRegister) {
9902
+ assert(OrigIdx == 0);
9903
+ return Op;
9904
+ }
9905
+
9906
+ SDValue Insert =
9907
+ DAG.getTargetInsertSubreg(SubRegIdx, DL, ContainerVecVT, Vec, SubVec);
9908
+ if (VecVT.isFixedLengthVector())
9909
+ Insert = convertFromScalableVector(VecVT, Insert, DAG, Subtarget);
9910
+ return Insert;
9911
+ }
9848
9912
return Op;
9913
+ }
9849
9914
9850
9915
// VSLIDEUP works by leaving elements 0<i<OFFSET undisturbed, elements
9851
9916
// OFFSET<=i<VL set to the "subvector" and vl<=i<VLMAX set to the tail policy
9852
9917
// (in our case undisturbed). This means we can set up a subvector insertion
9853
9918
// where OFFSET is the insertion offset, and the VL is the OFFSET plus the
9854
9919
// size of the subvector.
9855
- MVT InterSubVT = VecVT ;
9920
+ MVT InterSubVT = ContainerVecVT ;
9856
9921
SDValue AlignedExtract = Vec;
9857
- unsigned AlignedIdx = OrigIdx - RemIdx;
9858
- if (VecVT.bitsGT(getLMUL1VT(VecVT))) {
9859
- InterSubVT = getLMUL1VT(VecVT);
9922
+ unsigned AlignedIdx = OrigIdx - RemIdx.getKnownMinValue();
9923
+ if (SubVecVT.isFixedLengthVector())
9924
+ AlignedIdx /= *VLen / RISCV::RVVBitsPerBlock;
9925
+ if (ContainerVecVT.bitsGT(getLMUL1VT(ContainerVecVT))) {
9926
+ InterSubVT = getLMUL1VT(ContainerVecVT);
9860
9927
// Extract a subvector equal to the nearest full vector register type. This
9861
9928
// should resolve to a EXTRACT_SUBREG instruction.
9862
9929
AlignedExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InterSubVT, Vec,
@@ -9867,25 +9934,23 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
9867
9934
DAG.getUNDEF(InterSubVT), SubVec,
9868
9935
DAG.getVectorIdxConstant(0, DL));
9869
9936
9870
- auto [Mask, VL] = getDefaultScalableVLOps (VecVT, DL, DAG, Subtarget);
9937
+ auto [Mask, VL] = getDefaultVLOps (VecVT, ContainerVecVT , DL, DAG, Subtarget);
9871
9938
9872
- ElementCount EndIndex =
9873
- ElementCount::getScalable(RemIdx) + SubVecVT.getVectorElementCount();
9874
- VL = computeVLMax(SubVecVT, DL, DAG);
9939
+ ElementCount EndIndex = RemIdx + SubVecVT.getVectorElementCount();
9940
+ VL = DAG.getElementCount(DL, XLenVT, SubVecVT.getVectorElementCount());
9875
9941
9876
9942
// Use tail agnostic policy if we're inserting over InterSubVT's tail.
9877
9943
unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
9878
- if (EndIndex == InterSubVT.getVectorElementCount())
9944
+ if (isKnownEQ( EndIndex, InterSubVT.getVectorElementCount(), Subtarget ))
9879
9945
Policy = RISCVII::TAIL_AGNOSTIC;
9880
9946
9881
9947
// If we're inserting into the lowest elements, use a tail undisturbed
9882
9948
// vmv.v.v.
9883
- if (RemIdx == 0 ) {
9949
+ if (RemIdx.isZero() ) {
9884
9950
SubVec = DAG.getNode(RISCVISD::VMV_V_V_VL, DL, InterSubVT, AlignedExtract,
9885
9951
SubVec, VL);
9886
9952
} else {
9887
- SDValue SlideupAmt =
9888
- DAG.getVScale(DL, XLenVT, APInt(XLenVT.getSizeInBits(), RemIdx));
9953
+ SDValue SlideupAmt = DAG.getElementCount(DL, XLenVT, RemIdx);
9889
9954
9890
9955
// Construct the vector length corresponding to RemIdx + length(SubVecVT).
9891
9956
VL = DAG.getNode(ISD::ADD, DL, XLenVT, SlideupAmt, VL);
@@ -9896,10 +9961,13 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
9896
9961
9897
9962
// If required, insert this subvector back into the correct vector register.
9898
9963
// This should resolve to an INSERT_SUBREG instruction.
9899
- if (VecVT .bitsGT(InterSubVT))
9900
- SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecVT , Vec, SubVec,
9964
+ if (ContainerVecVT .bitsGT(InterSubVT))
9965
+ SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVecVT , Vec, SubVec,
9901
9966
DAG.getVectorIdxConstant(AlignedIdx, DL));
9902
9967
9968
+ if (VecVT.isFixedLengthVector())
9969
+ SubVec = convertFromScalableVector(VecVT, SubVec, DAG, Subtarget);
9970
+
9903
9971
// We might have bitcast from a mask type: cast back to the original type if
9904
9972
// required.
9905
9973
return DAG.getBitcast(Op.getSimpleValueType(), SubVec);
0 commit comments