@@ -13527,6 +13527,40 @@ static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
13527
13527
return true;
13528
13528
}
13529
13529
13530
+ /// Match the index vector of a scatter or gather node as the shuffle mask
13531
+ /// which performs the rearrangement if possible. Will only match if
13532
+ /// all lanes are touched, and thus replacing the scatter or gather with
13533
+ /// a unit strided access and shuffle is legal.
13534
+ static bool matchIndexAsShuffle(EVT VT, SDValue Index, SDValue Mask,
13535
+ SmallVector<int> &ShuffleMask) {
13536
+ if (!ISD::isConstantSplatVectorAllOnes(Mask.getNode()))
13537
+ return false;
13538
+ if (!ISD::isBuildVectorOfConstantSDNodes(Index.getNode()))
13539
+ return false;
13540
+
13541
+ const unsigned ElementSize = VT.getScalarStoreSize();
13542
+ const unsigned NumElems = VT.getVectorNumElements();
13543
+
13544
+ // Create the shuffle mask and check all bits active
13545
+ assert(ShuffleMask.empty());
13546
+ BitVector ActiveLanes(NumElems);
13547
+ for (unsigned i = 0; i < Index->getNumOperands(); i++) {
13548
+ // TODO: We've found an active bit of UB, and could be
13549
+ // more aggressive here if desired.
13550
+ if (Index->getOperand(i)->isUndef())
13551
+ return false;
13552
+ uint64_t C = Index->getConstantOperandVal(i);
13553
+ if (C % ElementSize != 0)
13554
+ return false;
13555
+ C = C / ElementSize;
13556
+ if (C >= NumElems)
13557
+ return false;
13558
+ ShuffleMask.push_back(C);
13559
+ ActiveLanes.set(C);
13560
+ }
13561
+ return ActiveLanes.all();
13562
+ }
13563
+
13530
13564
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
13531
13565
DAGCombinerInfo &DCI) const {
13532
13566
SelectionDAG &DAG = DCI.DAG;
@@ -13874,6 +13908,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
13874
13908
}
13875
13909
case ISD::MGATHER: {
13876
13910
const auto *MGN = dyn_cast<MaskedGatherSDNode>(N);
13911
+ const EVT VT = N->getValueType(0);
13877
13912
SDValue Index = MGN->getIndex();
13878
13913
SDValue ScaleOp = MGN->getScale();
13879
13914
ISD::MemIndexType IndexType = MGN->getIndexType();
@@ -13894,6 +13929,19 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
13894
13929
{MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
13895
13930
MGN->getBasePtr(), Index, ScaleOp},
13896
13931
MGN->getMemOperand(), IndexType, MGN->getExtensionType());
13932
+
13933
+ SmallVector<int> ShuffleMask;
13934
+ if (MGN->getExtensionType() == ISD::NON_EXTLOAD &&
13935
+ matchIndexAsShuffle(VT, Index, MGN->getMask(), ShuffleMask)) {
13936
+ SDValue Load = DAG.getMaskedLoad(VT, DL, MGN->getChain(),
13937
+ MGN->getBasePtr(), DAG.getUNDEF(XLenVT),
13938
+ MGN->getMask(), DAG.getUNDEF(VT),
13939
+ MGN->getMemoryVT(), MGN->getMemOperand(),
13940
+ ISD::UNINDEXED, ISD::NON_EXTLOAD);
13941
+ SDValue Shuffle =
13942
+ DAG.getVectorShuffle(VT, DL, Load, DAG.getUNDEF(VT), ShuffleMask);
13943
+ return DAG.getMergeValues({Shuffle, Load.getValue(1)}, DL);
13944
+ }
13897
13945
break;
13898
13946
}
13899
13947
case ISD::MSCATTER:{
@@ -13918,6 +13966,18 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
13918
13966
{MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
13919
13967
Index, ScaleOp},
13920
13968
MSN->getMemOperand(), IndexType, MSN->isTruncatingStore());
13969
+
13970
+ EVT VT = MSN->getValue()->getValueType(0);
13971
+ SmallVector<int> ShuffleMask;
13972
+ if (!MSN->isTruncatingStore() &&
13973
+ matchIndexAsShuffle(VT, Index, MSN->getMask(), ShuffleMask)) {
13974
+ SDValue Shuffle = DAG.getVectorShuffle(VT, DL, MSN->getValue(),
13975
+ DAG.getUNDEF(VT), ShuffleMask);
13976
+ return DAG.getMaskedStore(MSN->getChain(), DL, Shuffle, MSN->getBasePtr(),
13977
+ DAG.getUNDEF(XLenVT), MSN->getMask(),
13978
+ MSN->getMemoryVT(), MSN->getMemOperand(),
13979
+ ISD::UNINDEXED, false);
13980
+ }
13921
13981
break;
13922
13982
}
13923
13983
case ISD::VP_GATHER: {
0 commit comments