Skip to content

Commit 3ea191e

Browse files
committed
[RISCV] Factor repeating code into getMaskTypeFor(VT) [nfc]
1 parent 813e521 commit 3ea191e

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,13 +1568,20 @@ static SDValue convertFromScalableVector(EVT VT, SDValue V, SelectionDAG &DAG,
15681568
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero);
15691569
}
15701570

1571+
/// Return the type of the mask type suitable for masking the provided
1572+
/// vector type. This is simply an i1 element type vector of the same
1573+
/// (possibly scalable) length.
1574+
static MVT getMaskTypeFor(EVT VecVT) {
1575+
assert(VecVT.isVector());
1576+
ElementCount EC = VecVT.getVectorElementCount();
1577+
return MVT::getVectorVT(MVT::i1, EC);
1578+
}
1579+
15711580
/// Creates an all ones mask suitable for masking a vector of type VecTy with
15721581
/// vector length VL. .
15731582
static SDValue getAllOnesMask(MVT VecVT, SDValue VL, SDLoc DL,
15741583
SelectionDAG &DAG) {
1575-
assert(VecVT.isVector());
1576-
ElementCount EC = VecVT.getVectorElementCount();
1577-
MVT MaskVT = MVT::getVectorVT(MVT::i1, EC);
1584+
MVT MaskVT = getMaskTypeFor(VecVT);
15781585
return DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
15791586
}
15801587

@@ -4237,8 +4244,7 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
42374244
ContainerVT = getContainerForFixedLengthVector(SrcVT);
42384245
Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
42394246
if (IsVPTrunc) {
4240-
MVT MaskVT =
4241-
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
4247+
MVT MaskVT = getMaskTypeFor(ContainerVT);
42424248
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
42434249
}
42444250
}
@@ -4298,8 +4304,7 @@ SDValue RISCVTargetLowering::lowerVectorFPRoundLike(SDValue Op,
42984304
SrcContainerVT.changeVectorElementType(VT.getVectorElementType());
42994305
Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
43004306
if (IsVPFPTrunc) {
4301-
MVT MaskVT =
4302-
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
4307+
MVT MaskVT = getMaskTypeFor(ContainerVT);
43034308
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
43044309
}
43054310
}
@@ -4807,7 +4812,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
48074812
DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF(VT),
48084813
DAG.getConstant(0, DL, MVT::i32), VL);
48094814

4810-
MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorElementCount());
4815+
MVT MaskVT = getMaskTypeFor(VT);
48114816
SDValue Mask = getAllOnesMask(VT, VL, DL, DAG);
48124817
SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, VT, Mask, VL);
48134818
SDValue SelectCond =
@@ -4841,8 +4846,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
48414846

48424847
SDValue PassThru = Op.getOperand(2);
48434848
if (!IsUnmasked) {
4844-
MVT MaskVT =
4845-
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
4849+
MVT MaskVT = getMaskTypeFor(ContainerVT);
48464850
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
48474851
PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
48484852
}
@@ -4939,8 +4943,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
49394943

49404944
Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
49414945
if (!IsUnmasked) {
4942-
MVT MaskVT =
4943-
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
4946+
MVT MaskVT = getMaskTypeFor(ContainerVT);
49444947
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
49454948
}
49464949

@@ -5791,8 +5794,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
57915794
ContainerVT = getContainerForFixedLengthVector(VT);
57925795
PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
57935796
if (!IsUnmasked) {
5794-
MVT MaskVT =
5795-
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
5797+
MVT MaskVT = getMaskTypeFor(ContainerVT);
57965798
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
57975799
}
57985800
}
@@ -5858,8 +5860,7 @@ SDValue RISCVTargetLowering::lowerMaskedStore(SDValue Op,
58585860

58595861
Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
58605862
if (!IsUnmasked) {
5861-
MVT MaskVT =
5862-
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
5863+
MVT MaskVT = getMaskTypeFor(ContainerVT);
58635864
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
58645865
}
58655866
}
@@ -5897,7 +5898,7 @@ RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
58975898
SDValue VL =
58985899
DAG.getConstant(VT.getVectorNumElements(), DL, Subtarget.getXLenVT());
58995900

5900-
MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
5901+
MVT MaskVT = getMaskTypeFor(ContainerVT);
59015902
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
59025903

59035904
SDValue Cmp = DAG.getNode(RISCVISD::SETCC_VL, DL, MaskVT, Op1, Op2,
@@ -6200,7 +6201,7 @@ SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG,
62006201
DstVT = getContainerForFixedLengthVector(DstVT);
62016202
SrcVT = getContainerForFixedLengthVector(SrcVT);
62026203
Src = convertToScalableVector(SrcVT, Src, DAG, Subtarget);
6203-
MVT MaskVT = MVT::getVectorVT(MVT::i1, DstVT.getVectorElementCount());
6204+
MVT MaskVT = getMaskTypeFor(DstVT);
62046205
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
62056206
}
62066207

@@ -6413,8 +6414,7 @@ SDValue RISCVTargetLowering::lowerMaskedGather(SDValue Op,
64136414
Index = convertToScalableVector(IndexVT, Index, DAG, Subtarget);
64146415

64156416
if (!IsUnmasked) {
6416-
MVT MaskVT =
6417-
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
6417+
MVT MaskVT = getMaskTypeFor(ContainerVT);
64186418
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
64196419
PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget);
64206420
}
@@ -6525,8 +6525,7 @@ SDValue RISCVTargetLowering::lowerMaskedScatter(SDValue Op,
65256525
Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget);
65266526

65276527
if (!IsUnmasked) {
6528-
MVT MaskVT =
6529-
MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount());
6528+
MVT MaskVT = getMaskTypeFor(ContainerVT);
65306529
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
65316530
}
65326531
}
@@ -8813,7 +8812,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
88138812
// The memory VT and the element type must match.
88148813
if (VecVT.getVectorElementType() == MemVT) {
88158814
SDLoc DL(N);
8816-
MVT MaskVT = MVT::getVectorVT(MVT::i1, VecVT.getVectorElementCount());
8815+
MVT MaskVT = getMaskTypeFor(VecVT);
88178816
return DAG.getStoreVP(
88188817
Store->getChain(), DL, Src, Store->getBasePtr(), Store->getOffset(),
88198818
DAG.getConstant(1, DL, MaskVT),

0 commit comments

Comments
 (0)