Skip to content

Commit 4903c11

Browse files
authored
[RISCV] Support memcmp expansion for vectors
This patch adds the support of generating vector instructions for `memcmp`. This implementation is inspired by X86's. We convert integer comparisons (eq/ne only) into vector comparisons and do a vector reduction and to get the result. The range of supported load sizes is (XLEN, VLEN * LMUL8] and non-power-of-2 types are not supported. Fixes #143294. Reviewers: lukel97, asb, preames, topperc, dtcxzyw Reviewed By: topperc, lukel97 Pull Request: #114517
1 parent 483d196 commit 4903c11

File tree

4 files changed

+2474
-552
lines changed

4 files changed

+2474
-552
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16147,17 +16147,80 @@ static bool narrowIndex(SDValue &N, ISD::MemIndexType IndexType, SelectionDAG &D
1614716147
return true;
1614816148
}
1614916149

16150+
/// Try to map an integer comparison with size > XLEN to vector instructions
16151+
/// before type legalization splits it up into chunks.
16152+
static SDValue
16153+
combineVectorSizedSetCCEquality(EVT VT, SDValue X, SDValue Y, ISD::CondCode CC,
16154+
const SDLoc &DL, SelectionDAG &DAG,
16155+
const RISCVSubtarget &Subtarget) {
16156+
assert(ISD::isIntEqualitySetCC(CC) && "Bad comparison predicate");
16157+
16158+
if (!Subtarget.hasVInstructions())
16159+
return SDValue();
16160+
16161+
MVT XLenVT = Subtarget.getXLenVT();
16162+
EVT OpVT = X.getValueType();
16163+
// We're looking for an oversized integer equality comparison.
16164+
if (!OpVT.isScalarInteger())
16165+
return SDValue();
16166+
16167+
unsigned OpSize = OpVT.getSizeInBits();
16168+
// TODO: Support non-power-of-2 types.
16169+
if (!isPowerOf2_32(OpSize))
16170+
return SDValue();
16171+
16172+
// The size should be larger than XLen and smaller than the maximum vector
16173+
// size.
16174+
if (OpSize <= Subtarget.getXLen() ||
16175+
OpSize > Subtarget.getRealMinVLen() *
16176+
Subtarget.getMaxLMULForFixedLengthVectors())
16177+
return SDValue();
16178+
16179+
// Don't perform this combine if constructing the vector will be expensive.
16180+
auto IsVectorBitCastCheap = [](SDValue X) {
16181+
X = peekThroughBitcasts(X);
16182+
return isa<ConstantSDNode>(X) || X.getValueType().isVector() ||
16183+
X.getOpcode() == ISD::LOAD;
16184+
};
16185+
if (!IsVectorBitCastCheap(X) || !IsVectorBitCastCheap(Y))
16186+
return SDValue();
16187+
16188+
if (DAG.getMachineFunction().getFunction().hasFnAttribute(
16189+
Attribute::NoImplicitFloat))
16190+
return SDValue();
16191+
16192+
unsigned VecSize = OpSize / 8;
16193+
EVT VecVT = MVT::getVectorVT(MVT::i8, VecSize);
16194+
EVT CmpVT = MVT::getVectorVT(MVT::i1, VecSize);
16195+
16196+
SDValue VecX = DAG.getBitcast(VecVT, X);
16197+
SDValue VecY = DAG.getBitcast(VecVT, Y);
16198+
SDValue Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETNE);
16199+
return DAG.getSetCC(DL, VT, DAG.getNode(ISD::VECREDUCE_OR, DL, XLenVT, Cmp),
16200+
DAG.getConstant(0, DL, XLenVT), CC);
16201+
}
16202+
1615016203
// Replace (seteq (i64 (and X, 0xffffffff)), C1) with
1615116204
// (seteq (i64 (sext_inreg (X, i32)), C1')) where C1' is C1 sign extended from
1615216205
// bit 31. Same for setne. C1' may be cheaper to materialize and the sext_inreg
1615316206
// can become a sext.w instead of a shift pair.
1615416207
static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
1615516208
const RISCVSubtarget &Subtarget) {
16209+
SDLoc dl(N);
1615616210
SDValue N0 = N->getOperand(0);
1615716211
SDValue N1 = N->getOperand(1);
1615816212
EVT VT = N->getValueType(0);
1615916213
EVT OpVT = N0.getValueType();
1616016214

16215+
ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
16216+
// Looking for an equality compare.
16217+
if (!isIntEqualitySetCC(Cond))
16218+
return SDValue();
16219+
16220+
if (SDValue V =
16221+
combineVectorSizedSetCCEquality(VT, N0, N1, Cond, dl, DAG, Subtarget))
16222+
return V;
16223+
1616116224
if (OpVT != MVT::i64 || !Subtarget.is64Bit())
1616216225
return SDValue();
1616316226

@@ -16172,11 +16235,6 @@ static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
1617216235
N0.getConstantOperandVal(1) != UINT64_C(0xffffffff))
1617316236
return SDValue();
1617416237

16175-
// Looking for an equality compare.
16176-
ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
16177-
if (!isIntEqualitySetCC(Cond))
16178-
return SDValue();
16179-
1618016238
// Don't do this if the sign bit is provably zero, it will be turned back into
1618116239
// an AND.
1618216240
APInt SignMask = APInt::getOneBitSet(64, 31);
@@ -16185,7 +16243,6 @@ static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
1618516243

1618616244
const APInt &C1 = N1C->getAPIntValue();
1618716245

16188-
SDLoc dl(N);
1618916246
// If the constant is larger than 2^32 - 1 it is impossible for both sides
1619016247
// to be equal.
1619116248
if (C1.getActiveBits() > 32)

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2952,5 +2952,22 @@ RISCVTTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
29522952
Options.LoadSizes = {4, 2, 1};
29532953
Options.AllowedTailExpansions = {3};
29542954
}
2955+
2956+
if (IsZeroCmp && ST->hasVInstructions()) {
2957+
unsigned RealMinVLen = ST->getRealMinVLen();
2958+
// Support Fractional LMULs if the lengths are larger than XLen.
2959+
// TODO: Support non-power-of-2 types.
2960+
for (unsigned FLMUL = 8; FLMUL >= 2; FLMUL /= 2) {
2961+
unsigned Len = RealMinVLen / FLMUL;
2962+
if (Len > ST->getXLen())
2963+
Options.LoadSizes.insert(Options.LoadSizes.begin(), Len / 8);
2964+
}
2965+
for (unsigned LMUL = 1; LMUL <= ST->getMaxLMULForFixedLengthVectors();
2966+
LMUL *= 2) {
2967+
unsigned Len = RealMinVLen * LMUL;
2968+
if (Len > ST->getXLen())
2969+
Options.LoadSizes.insert(Options.LoadSizes.begin(), Len / 8);
2970+
}
2971+
}
29552972
return Options;
29562973
}

0 commit comments

Comments
 (0)