Skip to content

Commit 3dc7991

Browse files
authored
[RISCV] Add DAG combine to convert (iN reduce.add (zext (vXi1 A to vXiN)) into vcpop.m (#127497)
This patch combines (iN vector.reduce.add (zext (vXi1 A to vXiN)) into vcpop.m instruction (similarly to bitcast + ctpop pattern). It can be useful for counting number of set bits in scalable vector types, which can't be expressed with bitcast + ctpop (this was previously discussed here: #74294).
1 parent 9b4ad2f commit 3dc7991

File tree

2 files changed

+109
-157
lines changed

2 files changed

+109
-157
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,7 +1564,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
15641564
ISD::MUL, ISD::SDIV, ISD::UDIV,
15651565
ISD::SREM, ISD::UREM, ISD::INSERT_VECTOR_ELT,
15661566
ISD::ABS, ISD::CTPOP, ISD::VECTOR_SHUFFLE,
1567-
ISD::VSELECT});
1567+
ISD::VSELECT, ISD::VECREDUCE_ADD});
15681568

15691569
if (Subtarget.hasVendorXTHeadMemPair())
15701570
setTargetDAGCombine({ISD::LOAD, ISD::STORE});
@@ -18144,25 +18144,38 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
1814418144
// (iX ctpop (bitcast (vXi1 A)))
1814518145
// ->
1814618146
// (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A)))))
18147+
// and
18148+
// (iN reduce.add (zext (vXi1 A to vXiN))
18149+
// ->
18150+
// (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A)))))
1814718151
// FIXME: It's complicated to match all the variations of this after type
1814818152
// legalization so we only handle the pre-type legalization pattern, but that
1814918153
// requires the fixed vector type to be legal.
18150-
static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
18151-
const RISCVSubtarget &Subtarget) {
18154+
static SDValue combineToVCPOP(SDNode *N, SelectionDAG &DAG,
18155+
const RISCVSubtarget &Subtarget) {
18156+
unsigned Opc = N->getOpcode();
18157+
assert((Opc == ISD::CTPOP || Opc == ISD::VECREDUCE_ADD) &&
18158+
"Unexpected opcode");
1815218159
EVT VT = N->getValueType(0);
1815318160
if (!VT.isScalarInteger())
1815418161
return SDValue();
1815518162

1815618163
SDValue Src = N->getOperand(0);
1815718164

18158-
// Peek through zero_extend. It doesn't change the count.
18159-
if (Src.getOpcode() == ISD::ZERO_EXTEND)
18160-
Src = Src.getOperand(0);
18165+
if (Opc == ISD::CTPOP) {
18166+
// Peek through zero_extend. It doesn't change the count.
18167+
if (Src.getOpcode() == ISD::ZERO_EXTEND)
18168+
Src = Src.getOperand(0);
1816118169

18162-
if (Src.getOpcode() != ISD::BITCAST)
18163-
return SDValue();
18170+
if (Src.getOpcode() != ISD::BITCAST)
18171+
return SDValue();
18172+
Src = Src.getOperand(0);
18173+
} else if (Opc == ISD::VECREDUCE_ADD) {
18174+
if (Src.getOpcode() != ISD::ZERO_EXTEND)
18175+
return SDValue();
18176+
Src = Src.getOperand(0);
18177+
}
1816418178

18165-
Src = Src.getOperand(0);
1816618179
EVT SrcEVT = Src.getValueType();
1816718180
if (!SrcEVT.isSimple())
1816818181
return SDValue();
@@ -18172,11 +18185,28 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
1817218185
if (!SrcMVT.isVector() || SrcMVT.getVectorElementType() != MVT::i1)
1817318186
return SDValue();
1817418187

18175-
if (!useRVVForFixedLengthVectorVT(SrcMVT, Subtarget))
18176-
return SDValue();
18188+
// Check that destination type is large enough to hold result without
18189+
// overflow.
18190+
if (Opc == ISD::VECREDUCE_ADD) {
18191+
unsigned EltSize = SrcMVT.getScalarSizeInBits();
18192+
unsigned MinSize = SrcMVT.getSizeInBits().getKnownMinValue();
18193+
unsigned VectorBitsMax = Subtarget.getRealMaxVLen();
18194+
unsigned MaxVLMAX = SrcMVT.isFixedLengthVector()
18195+
? SrcMVT.getVectorNumElements()
18196+
: RISCVTargetLowering::computeVLMAX(
18197+
VectorBitsMax, EltSize, MinSize);
18198+
if (VT.getFixedSizeInBits() < Log2_32(MaxVLMAX) + 1)
18199+
return SDValue();
18200+
}
1817718201

18178-
MVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcMVT, Subtarget);
18179-
Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
18202+
MVT ContainerVT = SrcMVT;
18203+
if (SrcMVT.isFixedLengthVector()) {
18204+
if (!useRVVForFixedLengthVectorVT(SrcMVT, Subtarget))
18205+
return SDValue();
18206+
18207+
ContainerVT = getContainerForFixedLengthVector(DAG, SrcMVT, Subtarget);
18208+
Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
18209+
}
1818018210

1818118211
SDLoc DL(N);
1818218212
auto [Mask, VL] = getDefaultVLOps(SrcMVT, ContainerVT, DL, DAG, Subtarget);
@@ -19258,7 +19288,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1925819288
return SDValue();
1925919289
}
1926019290
case ISD::CTPOP:
19261-
if (SDValue V = combineScalarCTPOPToVCPOP(N, DAG, Subtarget))
19291+
case ISD::VECREDUCE_ADD:
19292+
if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
1926219293
return V;
1926319294
break;
1926419295
}

0 commit comments

Comments
 (0)