@@ -1568,13 +1568,20 @@ static SDValue convertFromScalableVector(EVT VT, SDValue V, SelectionDAG &DAG,
1568
1568
return DAG.getNode (ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero);
1569
1569
}
1570
1570
1571
+ // / Return the type of the mask type suitable for masking the provided
1572
+ // / vector type. This is simply an i1 element type vector of the same
1573
+ // / (possibly scalable) length.
1574
+ static MVT getMaskTypeFor (EVT VecVT) {
1575
+ assert (VecVT.isVector ());
1576
+ ElementCount EC = VecVT.getVectorElementCount ();
1577
+ return MVT::getVectorVT (MVT::i1, EC);
1578
+ }
1579
+
1571
1580
// / Creates an all ones mask suitable for masking a vector of type VecTy with
1572
1581
// / vector length VL. .
1573
1582
static SDValue getAllOnesMask (MVT VecVT, SDValue VL, SDLoc DL,
1574
1583
SelectionDAG &DAG) {
1575
- assert (VecVT.isVector ());
1576
- ElementCount EC = VecVT.getVectorElementCount ();
1577
- MVT MaskVT = MVT::getVectorVT (MVT::i1, EC);
1584
+ MVT MaskVT = getMaskTypeFor (VecVT);
1578
1585
return DAG.getNode (RISCVISD::VMSET_VL, DL, MaskVT, VL);
1579
1586
}
1580
1587
@@ -4237,8 +4244,7 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op,
4237
4244
ContainerVT = getContainerForFixedLengthVector (SrcVT);
4238
4245
Src = convertToScalableVector (ContainerVT, Src, DAG, Subtarget);
4239
4246
if (IsVPTrunc) {
4240
- MVT MaskVT =
4241
- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
4247
+ MVT MaskVT = getMaskTypeFor (ContainerVT);
4242
4248
Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
4243
4249
}
4244
4250
}
@@ -4298,8 +4304,7 @@ SDValue RISCVTargetLowering::lowerVectorFPRoundLike(SDValue Op,
4298
4304
SrcContainerVT.changeVectorElementType (VT.getVectorElementType ());
4299
4305
Src = convertToScalableVector (SrcContainerVT, Src, DAG, Subtarget);
4300
4306
if (IsVPFPTrunc) {
4301
- MVT MaskVT =
4302
- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
4307
+ MVT MaskVT = getMaskTypeFor (ContainerVT);
4303
4308
Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
4304
4309
}
4305
4310
}
@@ -4807,7 +4812,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
4807
4812
DAG.getNode (RISCVISD::VMV_V_X_VL, DL, VT, DAG.getUNDEF (VT),
4808
4813
DAG.getConstant (0 , DL, MVT::i32 ), VL);
4809
4814
4810
- MVT MaskVT = MVT::getVectorVT (MVT::i1, VT. getVectorElementCount () );
4815
+ MVT MaskVT = getMaskTypeFor (VT );
4811
4816
SDValue Mask = getAllOnesMask (VT, VL, DL, DAG);
4812
4817
SDValue VID = DAG.getNode (RISCVISD::VID_VL, DL, VT, Mask, VL);
4813
4818
SDValue SelectCond =
@@ -4841,8 +4846,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
4841
4846
4842
4847
SDValue PassThru = Op.getOperand (2 );
4843
4848
if (!IsUnmasked) {
4844
- MVT MaskVT =
4845
- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
4849
+ MVT MaskVT = getMaskTypeFor (ContainerVT);
4846
4850
Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
4847
4851
PassThru = convertToScalableVector (ContainerVT, PassThru, DAG, Subtarget);
4848
4852
}
@@ -4939,8 +4943,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
4939
4943
4940
4944
Val = convertToScalableVector (ContainerVT, Val, DAG, Subtarget);
4941
4945
if (!IsUnmasked) {
4942
- MVT MaskVT =
4943
- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
4946
+ MVT MaskVT = getMaskTypeFor (ContainerVT);
4944
4947
Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
4945
4948
}
4946
4949
@@ -5791,8 +5794,7 @@ SDValue RISCVTargetLowering::lowerMaskedLoad(SDValue Op,
5791
5794
ContainerVT = getContainerForFixedLengthVector (VT);
5792
5795
PassThru = convertToScalableVector (ContainerVT, PassThru, DAG, Subtarget);
5793
5796
if (!IsUnmasked) {
5794
- MVT MaskVT =
5795
- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
5797
+ MVT MaskVT = getMaskTypeFor (ContainerVT);
5796
5798
Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
5797
5799
}
5798
5800
}
@@ -5858,8 +5860,7 @@ SDValue RISCVTargetLowering::lowerMaskedStore(SDValue Op,
5858
5860
5859
5861
Val = convertToScalableVector (ContainerVT, Val, DAG, Subtarget);
5860
5862
if (!IsUnmasked) {
5861
- MVT MaskVT =
5862
- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
5863
+ MVT MaskVT = getMaskTypeFor (ContainerVT);
5863
5864
Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
5864
5865
}
5865
5866
}
@@ -5897,7 +5898,7 @@ RISCVTargetLowering::lowerFixedLengthVectorSetccToRVV(SDValue Op,
5897
5898
SDValue VL =
5898
5899
DAG.getConstant (VT.getVectorNumElements (), DL, Subtarget.getXLenVT ());
5899
5900
5900
- MVT MaskVT = MVT::getVectorVT (MVT::i1, ContainerVT. getVectorElementCount () );
5901
+ MVT MaskVT = getMaskTypeFor ( ContainerVT);
5901
5902
SDValue Mask = getAllOnesMask (ContainerVT, VL, DL, DAG);
5902
5903
5903
5904
SDValue Cmp = DAG.getNode (RISCVISD::SETCC_VL, DL, MaskVT, Op1, Op2,
@@ -6200,7 +6201,7 @@ SDValue RISCVTargetLowering::lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG,
6200
6201
DstVT = getContainerForFixedLengthVector (DstVT);
6201
6202
SrcVT = getContainerForFixedLengthVector (SrcVT);
6202
6203
Src = convertToScalableVector (SrcVT, Src, DAG, Subtarget);
6203
- MVT MaskVT = MVT::getVectorVT (MVT::i1, DstVT. getVectorElementCount () );
6204
+ MVT MaskVT = getMaskTypeFor ( DstVT);
6204
6205
Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
6205
6206
}
6206
6207
@@ -6413,8 +6414,7 @@ SDValue RISCVTargetLowering::lowerMaskedGather(SDValue Op,
6413
6414
Index = convertToScalableVector (IndexVT, Index, DAG, Subtarget);
6414
6415
6415
6416
if (!IsUnmasked) {
6416
- MVT MaskVT =
6417
- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
6417
+ MVT MaskVT = getMaskTypeFor (ContainerVT);
6418
6418
Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
6419
6419
PassThru = convertToScalableVector (ContainerVT, PassThru, DAG, Subtarget);
6420
6420
}
@@ -6525,8 +6525,7 @@ SDValue RISCVTargetLowering::lowerMaskedScatter(SDValue Op,
6525
6525
Val = convertToScalableVector (ContainerVT, Val, DAG, Subtarget);
6526
6526
6527
6527
if (!IsUnmasked) {
6528
- MVT MaskVT =
6529
- MVT::getVectorVT (MVT::i1, ContainerVT.getVectorElementCount ());
6528
+ MVT MaskVT = getMaskTypeFor (ContainerVT);
6530
6529
Mask = convertToScalableVector (MaskVT, Mask, DAG, Subtarget);
6531
6530
}
6532
6531
}
@@ -8813,7 +8812,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
8813
8812
// The memory VT and the element type must match.
8814
8813
if (VecVT.getVectorElementType () == MemVT) {
8815
8814
SDLoc DL (N);
8816
- MVT MaskVT = MVT::getVectorVT (MVT::i1, VecVT. getVectorElementCount () );
8815
+ MVT MaskVT = getMaskTypeFor ( VecVT);
8817
8816
return DAG.getStoreVP (
8818
8817
Store->getChain (), DL, Src, Store->getBasePtr (), Store->getOffset (),
8819
8818
DAG.getConstant (1 , DL, MaskVT),
0 commit comments