Skip to content

Commit ff2622b

Browse files
committed
[RISCV] Optimize gather/scatter to unit-stride memop + shuffle (#66279)
If we have a gather or a scatter whose index describes a permutation of the lanes, we can lower this as a shuffle + a unit-strided memory operation. For RISCV, this replaces an indexed load/store with a unit-strided memory operation and a vrgather (at worst). I did not bother to implement the vp.scatter and vp.gather variants of these transforms because they'd only be legal when EVL was VLMAX. Given that, they should have been transformed to the non-vp variants anyway. I haven't checked to see if they actually are.
1 parent ac182de commit ff2622b

File tree

3 files changed

+529
-21
lines changed

3 files changed

+529
-21
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13527,6 +13527,40 @@ static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
1352713527
return true;
1352813528
}
1352913529

13530+
/// Match the index vector of a scatter or gather node as the shuffle mask
13531+
/// which performs the rearrangement if possible. Will only match if
13532+
/// all lanes are touched, and thus replacing the scatter or gather with
13533+
/// a unit strided access and shuffle is legal.
13534+
static bool matchIndexAsShuffle(EVT VT, SDValue Index, SDValue Mask,
13535+
SmallVector<int> &ShuffleMask) {
13536+
if (!ISD::isConstantSplatVectorAllOnes(Mask.getNode()))
13537+
return false;
13538+
if (!ISD::isBuildVectorOfConstantSDNodes(Index.getNode()))
13539+
return false;
13540+
13541+
const unsigned ElementSize = VT.getScalarStoreSize();
13542+
const unsigned NumElems = VT.getVectorNumElements();
13543+
13544+
// Create the shuffle mask and check all bits active
13545+
assert(ShuffleMask.empty());
13546+
BitVector ActiveLanes(NumElems);
13547+
for (unsigned i = 0; i < Index->getNumOperands(); i++) {
13548+
// TODO: We've found an active bit of UB, and could be
13549+
// more aggressive here if desired.
13550+
if (Index->getOperand(i)->isUndef())
13551+
return false;
13552+
uint64_t C = Index->getConstantOperandVal(i);
13553+
if (C % ElementSize != 0)
13554+
return false;
13555+
C = C / ElementSize;
13556+
if (C >= NumElems)
13557+
return false;
13558+
ShuffleMask.push_back(C);
13559+
ActiveLanes.set(C);
13560+
}
13561+
return ActiveLanes.all();
13562+
}
13563+
1353013564
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1353113565
DAGCombinerInfo &DCI) const {
1353213566
SelectionDAG &DAG = DCI.DAG;
@@ -13874,6 +13908,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1387413908
}
1387513909
case ISD::MGATHER: {
1387613910
const auto *MGN = dyn_cast<MaskedGatherSDNode>(N);
13911+
const EVT VT = N->getValueType(0);
1387713912
SDValue Index = MGN->getIndex();
1387813913
SDValue ScaleOp = MGN->getScale();
1387913914
ISD::MemIndexType IndexType = MGN->getIndexType();
@@ -13894,6 +13929,19 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1389413929
{MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
1389513930
MGN->getBasePtr(), Index, ScaleOp},
1389613931
MGN->getMemOperand(), IndexType, MGN->getExtensionType());
13932+
13933+
SmallVector<int> ShuffleMask;
13934+
if (MGN->getExtensionType() == ISD::NON_EXTLOAD &&
13935+
matchIndexAsShuffle(VT, Index, MGN->getMask(), ShuffleMask)) {
13936+
SDValue Load = DAG.getMaskedLoad(VT, DL, MGN->getChain(),
13937+
MGN->getBasePtr(), DAG.getUNDEF(XLenVT),
13938+
MGN->getMask(), DAG.getUNDEF(VT),
13939+
MGN->getMemoryVT(), MGN->getMemOperand(),
13940+
ISD::UNINDEXED, ISD::NON_EXTLOAD);
13941+
SDValue Shuffle =
13942+
DAG.getVectorShuffle(VT, DL, Load, DAG.getUNDEF(VT), ShuffleMask);
13943+
return DAG.getMergeValues({Shuffle, Load.getValue(1)}, DL);
13944+
}
1389713945
break;
1389813946
}
1389913947
case ISD::MSCATTER:{
@@ -13918,6 +13966,18 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1391813966
{MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
1391913967
Index, ScaleOp},
1392013968
MSN->getMemOperand(), IndexType, MSN->isTruncatingStore());
13969+
13970+
EVT VT = MSN->getValue()->getValueType(0);
13971+
SmallVector<int> ShuffleMask;
13972+
if (!MSN->isTruncatingStore() &&
13973+
matchIndexAsShuffle(VT, Index, MSN->getMask(), ShuffleMask)) {
13974+
SDValue Shuffle = DAG.getVectorShuffle(VT, DL, MSN->getValue(),
13975+
DAG.getUNDEF(VT), ShuffleMask);
13976+
return DAG.getMaskedStore(MSN->getChain(), DL, Shuffle, MSN->getBasePtr(),
13977+
DAG.getUNDEF(XLenVT), MSN->getMask(),
13978+
MSN->getMemoryVT(), MSN->getMemOperand(),
13979+
ISD::UNINDEXED, false);
13980+
}
1392113981
break;
1392213982
}
1392313983
case ISD::VP_GATHER: {

0 commit comments

Comments
 (0)