@@ -1564,7 +1564,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
1564
1564
ISD::MUL, ISD::SDIV, ISD::UDIV,
1565
1565
ISD::SREM, ISD::UREM, ISD::INSERT_VECTOR_ELT,
1566
1566
ISD::ABS, ISD::CTPOP, ISD::VECTOR_SHUFFLE,
1567
- ISD::VSELECT});
1567
+ ISD::VSELECT, ISD::VECREDUCE_ADD });
1568
1568
1569
1569
if (Subtarget.hasVendorXTHeadMemPair())
1570
1570
setTargetDAGCombine({ISD::LOAD, ISD::STORE});
@@ -18144,25 +18144,38 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
18144
18144
// (iX ctpop (bitcast (vXi1 A)))
18145
18145
// ->
18146
18146
// (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A)))))
18147
+ // and
18148
+ // (iN reduce.add (zext (vXi1 A to vXiN))
18149
+ // ->
18150
+ // (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A)))))
18147
18151
// FIXME: It's complicated to match all the variations of this after type
18148
18152
// legalization so we only handle the pre-type legalization pattern, but that
18149
18153
// requires the fixed vector type to be legal.
18150
- static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
18151
- const RISCVSubtarget &Subtarget) {
18154
+ static SDValue combineToVCPOP(SDNode *N, SelectionDAG &DAG,
18155
+ const RISCVSubtarget &Subtarget) {
18156
+ unsigned Opc = N->getOpcode();
18157
+ assert((Opc == ISD::CTPOP || Opc == ISD::VECREDUCE_ADD) &&
18158
+ "Unexpected opcode");
18152
18159
EVT VT = N->getValueType(0);
18153
18160
if (!VT.isScalarInteger())
18154
18161
return SDValue();
18155
18162
18156
18163
SDValue Src = N->getOperand(0);
18157
18164
18158
- // Peek through zero_extend. It doesn't change the count.
18159
- if (Src.getOpcode() == ISD::ZERO_EXTEND)
18160
- Src = Src.getOperand(0);
18165
+ if (Opc == ISD::CTPOP) {
18166
+ // Peek through zero_extend. It doesn't change the count.
18167
+ if (Src.getOpcode() == ISD::ZERO_EXTEND)
18168
+ Src = Src.getOperand(0);
18161
18169
18162
- if (Src.getOpcode() != ISD::BITCAST)
18163
- return SDValue();
18170
+ if (Src.getOpcode() != ISD::BITCAST)
18171
+ return SDValue();
18172
+ Src = Src.getOperand(0);
18173
+ } else if (Opc == ISD::VECREDUCE_ADD) {
18174
+ if (Src.getOpcode() != ISD::ZERO_EXTEND)
18175
+ return SDValue();
18176
+ Src = Src.getOperand(0);
18177
+ }
18164
18178
18165
- Src = Src.getOperand(0);
18166
18179
EVT SrcEVT = Src.getValueType();
18167
18180
if (!SrcEVT.isSimple())
18168
18181
return SDValue();
@@ -18172,11 +18185,28 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
18172
18185
if (!SrcMVT.isVector() || SrcMVT.getVectorElementType() != MVT::i1)
18173
18186
return SDValue();
18174
18187
18175
- if (!useRVVForFixedLengthVectorVT(SrcMVT, Subtarget))
18176
- return SDValue();
18188
+ // Check that destination type is large enough to hold result without
18189
+ // overflow.
18190
+ if (Opc == ISD::VECREDUCE_ADD) {
18191
+ unsigned EltSize = SrcMVT.getScalarSizeInBits();
18192
+ unsigned MinSize = SrcMVT.getSizeInBits().getKnownMinValue();
18193
+ unsigned VectorBitsMax = Subtarget.getRealMaxVLen();
18194
+ unsigned MaxVLMAX = SrcMVT.isFixedLengthVector()
18195
+ ? SrcMVT.getVectorNumElements()
18196
+ : RISCVTargetLowering::computeVLMAX(
18197
+ VectorBitsMax, EltSize, MinSize);
18198
+ if (VT.getFixedSizeInBits() < Log2_32(MaxVLMAX) + 1)
18199
+ return SDValue();
18200
+ }
18177
18201
18178
- MVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcMVT, Subtarget);
18179
- Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
18202
+ MVT ContainerVT = SrcMVT;
18203
+ if (SrcMVT.isFixedLengthVector()) {
18204
+ if (!useRVVForFixedLengthVectorVT(SrcMVT, Subtarget))
18205
+ return SDValue();
18206
+
18207
+ ContainerVT = getContainerForFixedLengthVector(DAG, SrcMVT, Subtarget);
18208
+ Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
18209
+ }
18180
18210
18181
18211
SDLoc DL(N);
18182
18212
auto [Mask, VL] = getDefaultVLOps(SrcMVT, ContainerVT, DL, DAG, Subtarget);
@@ -19258,7 +19288,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
19258
19288
return SDValue();
19259
19289
}
19260
19290
case ISD::CTPOP:
19261
- if (SDValue V = combineScalarCTPOPToVCPOP(N, DAG, Subtarget))
19291
+ case ISD::VECREDUCE_ADD:
19292
+ if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
19262
19293
return V;
19263
19294
break;
19264
19295
}
0 commit comments