[RISCV] Refactor performCONCAT_VECTORSCombine. NFC #69068

Merged · 4 commits · Oct 16, 2023

92 changes: 34 additions & 58 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13785,11 +13785,10 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   EVT BaseLdVT = BaseLd->getValueType(0);
-  SDValue BasePtr = BaseLd->getBasePtr();
 
   // Go through the loads and check that they're strided
-  SmallVector<SDValue> Ptrs;
-  Ptrs.push_back(BasePtr);
+  SmallVector<LoadSDNode *> Lds;
+  Lds.push_back(BaseLd);
   Align Align = BaseLd->getAlign();
   for (SDValue Op : N->ops().drop_front()) {
     auto *Ld = dyn_cast<LoadSDNode>(Op);
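The loop entered here folds each load's alignment into a running minimum, so the common alignment ends up being the most restrictive (smallest) of all the loads. A trivial standalone sketch of that fold, with plain unsigned values standing in for llvm::Align (an assumed stand-in, not LLVM code):

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // Per-load alignments; unsigned stands in for llvm::Align here.
  std::vector<unsigned> LoadAligns = {16, 8, 4};

  unsigned Common = LoadAligns.front();
  for (unsigned A : LoadAligns)
    Common = std::min(Common, A); // most restrictive (smallest) wins

  assert(Common == 4); // the combined access must honor the weakest guarantee
}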
@@ -13798,60 +13797,38 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
         Ld->getValueType(0) != BaseLdVT)
       return SDValue();
 
-    Ptrs.push_back(Ld->getBasePtr());
+    Lds.push_back(Ld);
 
     // The common alignment is the most restrictive (smallest) of all the loads
     Align = std::min(Align, Ld->getAlign());
   }
 
-  auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == 0)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add LastPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()-1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
-  };
-  auto matchReverseStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == Ptrs.size() - 1)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add NextPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()+1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
-  };
+  using PtrDiff = std::pair<SDValue, bool>;
+  auto GetPtrDiff = [](LoadSDNode *Ld1,
+                       LoadSDNode *Ld2) -> std::optional<PtrDiff> {
+    SDValue P1 = Ld1->getBasePtr();
+    SDValue P2 = Ld2->getBasePtr();
+    if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1)
+      return {{P2.getOperand(1), false}};
+    if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2)
+      return {{P1.getOperand(1), true}};
+
+    return std::nullopt;
+  };
 
-  bool Reversed = false;
-  SDValue Stride = matchForwardStrided(Ptrs);
-  if (!Stride) {
-    Stride = matchReverseStrided(Ptrs);
-    Reversed = true;
-    // TODO: At this point, we've successfully matched a generalized gather
-    // load. Maybe we should emit that, and then move the specialized
-    // matchers above and below into a DAG combine?
-    if (!Stride)
-      return SDValue();
-  }
+  // Get the distance between the first and second loads
+  auto BaseDiff = GetPtrDiff(Lds[0], Lds[1]);
+  if (!BaseDiff)
+    return SDValue();
+
+  // Check all the loads are the same distance apart
+  for (auto *It = Lds.begin() + 1; It != Lds.end() - 1; It++)
+    if (GetPtrDiff(*It, *std::next(It)) != BaseDiff)
+      return SDValue();
+
+  // TODO: At this point, we've successfully matched a generalized gather
+  // load. Maybe we should emit that, and then move the specialized
+  // matchers above and below into a DAG combine?
 
   // Get the widened scalar type, e.g. v4i8 -> i64
   unsigned WideScalarBitWidth =
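For intuition about the refactor above, here is a minimal standalone sketch of the matching it performs. Plain integers stand in for base-pointer SDValues, and getPtrDiff subtracts addresses instead of matching (add Ptr, Stride) nodes, but the flow is the same: take the diff between the first two loads, require every adjacent pair to produce the same diff, and negate at the end if the loads run backwards. This is an assumed simplification, not LLVM code:

#include <cstdio>
#include <optional>
#include <utility>
#include <vector>

using PtrDiff = std::pair<long, bool>; // (offset, mustNegate)

// Stand-in for GetPtrDiff: report the distance between two load addresses
// and whether it points backwards.
std::optional<PtrDiff> getPtrDiff(long P1, long P2) {
  if (P2 > P1)
    return {{P2 - P1, false}}; // forward: P2 == P1 + Off
  if (P1 > P2)
    return {{P1 - P2, true}};  // reverse: P1 == P2 + Off, negate later
  return std::nullopt;         // same address: not strided
}

int main() {
  // Base addresses of four loads feeding a concat, reverse-strided by 16.
  std::vector<long> Loads = {64, 48, 32, 16};

  // Get the distance between the first and second loads.
  auto BaseDiff = getPtrDiff(Loads[0], Loads[1]);
  if (!BaseDiff) {
    std::puts("not strided");
    return 1;
  }

  // Check all the loads are the same distance apart.
  for (size_t I = 1; I + 1 < Loads.size(); ++I)
    if (getPtrDiff(Loads[I], Loads[I + 1]) != BaseDiff) {
      std::puts("not strided");
      return 1;
    }

  auto [Stride, MustNegate] = *BaseDiff;
  std::printf("stride = %ld\n", MustNegate ? -Stride : Stride); // prints -16
}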
@@ -13867,26 +13844,25 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
     return SDValue();
 
+  auto [Stride, MustNegateStride] = *BaseDiff;
+  if (MustNegateStride)
+    Stride = DAG.getNegative(Stride, DL, Stride.getValueType());
+
   SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
   SDValue IntID =
       DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
                             Subtarget.getXLenVT());
-  if (Reversed)
-    Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
 
   SDValue AllOneMask =
       DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
                    DAG.getConstant(1, DL, MVT::i1));
 
-  SDValue Ops[] = {BaseLd->getChain(),
-                   IntID,
-                   DAG.getUNDEF(WideVecVT),
-                   BasePtr,
-                   Stride,
-                   AllOneMask};
+  SDValue Ops[] = {BaseLd->getChain(), IntID, DAG.getUNDEF(WideVecVT),
+                   BaseLd->getBasePtr(), Stride, AllOneMask};
 
   uint64_t MemSize;
   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
-      ConstStride && !Reversed && ConstStride->getSExtValue() >= 0)
+      ConstStride && ConstStride->getSExtValue() >= 0)
     // total size = (elsize * n) + (stride - elsize) * (n-1)
     //            = elsize + stride * (n-1)
     MemSize = WideScalarVT.getSizeInBits() +
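One last note on the MemSize computation at the end: it only applies when the stride is a known non-negative constant, presumably because a negative stride walks backwards from the base pointer, so the simple footprint formula would not describe the accessed range. A worked check of the algebra from the comment, using assumed example numbers (n = 4 loads, a 64-bit widened element, a 128-bit stride; the units just need to be consistent):

#include <cassert>

int main() {
  unsigned ElSize = 64;  // widened element size, in bits
  unsigned Stride = 128; // distance between consecutive elements, in bits
  unsigned N = 4;        // number of loads in the concat

  // total size = (elsize * n) + (stride - elsize) * (n-1)
  unsigned Expanded = ElSize * N + (Stride - ElSize) * (N - 1);
  //            = elsize + stride * (n-1)
  unsigned Simplified = ElSize + Stride * (N - 1);

  assert(Expanded == Simplified); // both are 448 bits here: 64 + 128 * 3
  return 0;
}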