Skip to content

[RISCV] Add DAG combine to convert (iN reduce.add (zext (vXi1 A to vXiN)) into vcpop.m #127497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 45 additions & 14 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1564,7 +1564,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::MUL, ISD::SDIV, ISD::UDIV,
ISD::SREM, ISD::UREM, ISD::INSERT_VECTOR_ELT,
ISD::ABS, ISD::CTPOP, ISD::VECTOR_SHUFFLE,
ISD::VSELECT});
ISD::VSELECT, ISD::VECREDUCE_ADD});

if (Subtarget.hasVendorXTHeadMemPair())
setTargetDAGCombine({ISD::LOAD, ISD::STORE});
Expand Down Expand Up @@ -18144,25 +18144,38 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
// (iX ctpop (bitcast (vXi1 A)))
// ->
// (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A)))))
// and
// (iN reduce.add (zext (vXi1 A to vXiN))
// ->
// (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A)))))
// FIXME: It's complicated to match all the variations of this after type
// legalization so we only handle the pre-type legalization pattern, but that
// requires the fixed vector type to be legal.
static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
static SDValue combineToVCPOP(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
unsigned Opc = N->getOpcode();
assert((Opc == ISD::CTPOP || Opc == ISD::VECREDUCE_ADD) &&
"Unexpected opcode");
EVT VT = N->getValueType(0);
if (!VT.isScalarInteger())
return SDValue();

SDValue Src = N->getOperand(0);

// Peek through zero_extend. It doesn't change the count.
if (Src.getOpcode() == ISD::ZERO_EXTEND)
Src = Src.getOperand(0);
if (Opc == ISD::CTPOP) {
// Peek through zero_extend. It doesn't change the count.
if (Src.getOpcode() == ISD::ZERO_EXTEND)
Src = Src.getOperand(0);

if (Src.getOpcode() != ISD::BITCAST)
return SDValue();
if (Src.getOpcode() != ISD::BITCAST)
return SDValue();
Src = Src.getOperand(0);
} else if (Opc == ISD::VECREDUCE_ADD) {
if (Src.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a subtle, nasty bug here.

Consider the case where the pattern is: (iN reduce.add (zext (vXi1 A to vXi4))

If runtime VLENB is such that the number of mask bits is greater than 16, this is not equal to the vcpop - due the wrapping behavior on the add reduce. You need to prove that the intermediate type is sufficiently wide to hold the element count of the mask source without overflow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, thank you! I've made a fix that estimates possible maximum number of elements in mask and ensures that destination type is large enough to hold it. However, I've discovered that for this particular case (zext <16 x i1> to <16 x i4> + reduce.add) vcpop.m is still generated because type was extended to i8 during type legalization. This promotion also can be observed here: https://godbolt.org/z/avWEdcjvW, and it's confusing because if we have all-ones mask as an input, we should give zero result due to wrapping behaviour of add reduce, but the generated code will return 16 in this case. Also, I can't find any documentation on oveflowing behavior for integer reductions (intrinsics or ISD nodes)...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type of your function is declared as i4 without any attributes. Only the lowest 4 bits of the result are required to be valid. The upper bits can be any value. If you add a zeroext attribute to the return value an andi instruction will be generated to clear the upper bits.

define zeroext i4 @test_narrow_v16i1(<16 x i1> %x) {
entry:
    %a = zext <16 x i1> %x to <16 x i4>
    %b = call i4 @llvm.vector.reduce.add.v16i4(<16 x i4> %a)
    ret i4 %b
}

Type promotion makes it the responsibility of the consumer to zero or sign extend upper bits if needed. With no attributes, the consumer is an any_extend so the upper bits don't need to be touched. With zeroext the consumer is a zero_extend so the bits need to be cleared to match the semantics of the unpromoted zero_extend

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense now, thanks! Considering this producing vcpop.m in test_narrow_v16i1 seems correct - for all-ones mask we correctly set lowest 4 bits to zero

Src = Src.getOperand(0);
}

Src = Src.getOperand(0);
EVT SrcEVT = Src.getValueType();
if (!SrcEVT.isSimple())
return SDValue();
Expand All @@ -18172,11 +18185,28 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
if (!SrcMVT.isVector() || SrcMVT.getVectorElementType() != MVT::i1)
return SDValue();

if (!useRVVForFixedLengthVectorVT(SrcMVT, Subtarget))
return SDValue();
// Check that destination type is large enough to hold result without
// overflow.
if (Opc == ISD::VECREDUCE_ADD) {
unsigned EltSize = SrcMVT.getScalarSizeInBits();
unsigned MinSize = SrcMVT.getSizeInBits().getKnownMinValue();
unsigned VectorBitsMax = Subtarget.getRealMaxVLen();
unsigned MaxVLMAX = SrcMVT.isFixedLengthVector()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a test for the fixed vector case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added (BTW, I think it will be a rare case on practice because InstCombine converts zext+reduce.add with fixed vector to bitcast + ctpop)

? SrcMVT.getVectorNumElements()
: RISCVTargetLowering::computeVLMAX(
VectorBitsMax, EltSize, MinSize);
if (VT.getFixedSizeInBits() < Log2_32(MaxVLMAX) + 1)
return SDValue();
}

MVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcMVT, Subtarget);
Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
MVT ContainerVT = SrcMVT;
if (SrcMVT.isFixedLengthVector()) {
if (!useRVVForFixedLengthVectorVT(SrcMVT, Subtarget))
return SDValue();

ContainerVT = getContainerForFixedLengthVector(DAG, SrcMVT, Subtarget);
Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget);
}

SDLoc DL(N);
auto [Mask, VL] = getDefaultVLOps(SrcMVT, ContainerVT, DL, DAG, Subtarget);
Expand Down Expand Up @@ -19258,7 +19288,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return SDValue();
}
case ISD::CTPOP:
if (SDValue V = combineScalarCTPOPToVCPOP(N, DAG, Subtarget))
case ISD::VECREDUCE_ADD:
if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
return V;
break;
}
Expand Down
Loading