-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[RISCV] Add DAG combine to convert (iN reduce.add (zext (vXi1 A) to vXiN)) into vcpop.m #127497
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1564,7 +1564,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, | |
ISD::MUL, ISD::SDIV, ISD::UDIV, | ||
ISD::SREM, ISD::UREM, ISD::INSERT_VECTOR_ELT, | ||
ISD::ABS, ISD::CTPOP, ISD::VECTOR_SHUFFLE, | ||
ISD::VSELECT}); | ||
ISD::VSELECT, ISD::VECREDUCE_ADD}); | ||
|
||
if (Subtarget.hasVendorXTHeadMemPair()) | ||
setTargetDAGCombine({ISD::LOAD, ISD::STORE}); | ||
|
@@ -18144,25 +18144,38 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG, | |
// (iX ctpop (bitcast (vXi1 A))) | ||
// -> | ||
// (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A))))) | ||
// and | ||
// (iN reduce.add (zext (vXi1 A) to vXiN)) | ||
// -> | ||
// (zext (vcpop.m (nxvYi1 (insert_subvec (vXi1 A))))) | ||
// FIXME: It's complicated to match all the variations of this after type | ||
// legalization so we only handle the pre-type legalization pattern, but that | ||
// requires the fixed vector type to be legal. | ||
static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG, | ||
const RISCVSubtarget &Subtarget) { | ||
static SDValue combineToVCPOP(SDNode *N, SelectionDAG &DAG, | ||
const RISCVSubtarget &Subtarget) { | ||
unsigned Opc = N->getOpcode(); | ||
assert((Opc == ISD::CTPOP || Opc == ISD::VECREDUCE_ADD) && | ||
"Unexpected opcode"); | ||
EVT VT = N->getValueType(0); | ||
if (!VT.isScalarInteger()) | ||
return SDValue(); | ||
|
||
SDValue Src = N->getOperand(0); | ||
|
||
// Peek through zero_extend. It doesn't change the count. | ||
if (Src.getOpcode() == ISD::ZERO_EXTEND) | ||
Src = Src.getOperand(0); | ||
if (Opc == ISD::CTPOP) { | ||
// Peek through zero_extend. It doesn't change the count. | ||
if (Src.getOpcode() == ISD::ZERO_EXTEND) | ||
Src = Src.getOperand(0); | ||
|
||
if (Src.getOpcode() != ISD::BITCAST) | ||
return SDValue(); | ||
if (Src.getOpcode() != ISD::BITCAST) | ||
return SDValue(); | ||
Src = Src.getOperand(0); | ||
} else if (Opc == ISD::VECREDUCE_ADD) { | ||
if (Src.getOpcode() != ISD::ZERO_EXTEND) | ||
return SDValue(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There's a subtle, nasty bug here. Consider the case where the pattern is: (iN reduce.add (zext (vXi1 A) to vXi4)). If runtime VLENB is such that the number of mask bits is greater than 16, this is not equal to the vcpop, due to the wrapping behavior of the add reduce. You need to prove that the intermediate type is sufficiently wide to hold the element count of the mask source without overflow. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good catch, thank you! I've made a fix that estimates the maximum possible number of elements in the mask and ensures that the destination type is large enough to hold it. However, I've discovered that for this particular case (zext <16 x i1> to <16 x i4> + reduce.add) vcpop.m is still generated because the type was extended to i8 during type legalization. This promotion can also be observed here: https://godbolt.org/z/avWEdcjvW, and it's confusing because if we have an all-ones mask as an input, we should give a zero result due to the wrapping behaviour of the add reduce, but the generated code will return 16 in this case. Also, I can't find any documentation on overflow behavior for integer reductions (intrinsics or ISD nodes)... There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The return type of your function is declared as i4 without any attributes. Only the lowest 4 bits of the result are required to be valid. The upper bits can be any value. If you add a
Type promotion makes it the responsibility of the consumer to zero or sign extend upper bits if needed. With no attributes, the consumer is an There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense now, thanks! Considering this, producing vcpop.m in test_narrow_v16i1 seems correct: for an all-ones mask we correctly set the lowest 4 bits to zero |
||
Src = Src.getOperand(0); | ||
} | ||
|
||
Src = Src.getOperand(0); | ||
EVT SrcEVT = Src.getValueType(); | ||
if (!SrcEVT.isSimple()) | ||
return SDValue(); | ||
|
@@ -18172,11 +18185,28 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG, | |
if (!SrcMVT.isVector() || SrcMVT.getVectorElementType() != MVT::i1) | ||
return SDValue(); | ||
|
||
if (!useRVVForFixedLengthVectorVT(SrcMVT, Subtarget)) | ||
return SDValue(); | ||
// Check that destination type is large enough to hold result without | ||
// overflow. | ||
if (Opc == ISD::VECREDUCE_ADD) { | ||
unsigned EltSize = SrcMVT.getScalarSizeInBits(); | ||
unsigned MinSize = SrcMVT.getSizeInBits().getKnownMinValue(); | ||
unsigned VectorBitsMax = Subtarget.getRealMaxVLen(); | ||
unsigned MaxVLMAX = SrcMVT.isFixedLengthVector() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there a test for the fixed vector case? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added (BTW, I think it will be a rare case in practice because InstCombine converts zext+reduce.add with a fixed vector to bitcast + ctpop) |
||
? SrcMVT.getVectorNumElements() | ||
: RISCVTargetLowering::computeVLMAX( | ||
VectorBitsMax, EltSize, MinSize); | ||
if (VT.getFixedSizeInBits() < Log2_32(MaxVLMAX) + 1) | ||
return SDValue(); | ||
} | ||
|
||
MVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcMVT, Subtarget); | ||
Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget); | ||
MVT ContainerVT = SrcMVT; | ||
if (SrcMVT.isFixedLengthVector()) { | ||
if (!useRVVForFixedLengthVectorVT(SrcMVT, Subtarget)) | ||
return SDValue(); | ||
|
||
ContainerVT = getContainerForFixedLengthVector(DAG, SrcMVT, Subtarget); | ||
Src = convertToScalableVector(ContainerVT, Src, DAG, Subtarget); | ||
} | ||
|
||
SDLoc DL(N); | ||
auto [Mask, VL] = getDefaultVLOps(SrcMVT, ContainerVT, DL, DAG, Subtarget); | ||
|
@@ -19258,7 +19288,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, | |
return SDValue(); | ||
} | ||
case ISD::CTPOP: | ||
if (SDValue V = combineScalarCTPOPToVCPOP(N, DAG, Subtarget)) | ||
case ISD::VECREDUCE_ADD: | ||
if (SDValue V = combineToVCPOP(N, DAG, Subtarget)) | ||
return V; | ||
break; | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.