Skip to content

[RISCV] Use VP strided load in concat_vectors combine #98131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16205,18 +16205,10 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
if (MustNegateStride)
Stride = DAG.getNegative(Stride, DL, Stride.getValueType());

SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
SDValue IntID =
DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
Subtarget.getXLenVT());

SDValue AllOneMask =
DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
DAG.getConstant(1, DL, MVT::i1));

SDValue Ops[] = {BaseLd->getChain(), IntID, DAG.getUNDEF(WideVecVT),
BaseLd->getBasePtr(), Stride, AllOneMask};

uint64_t MemSize;
if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
ConstStride && ConstStride->getSExtValue() >= 0)
Expand All @@ -16232,8 +16224,11 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
BaseLd->getPointerInfo(), BaseLd->getMemOperand()->getFlags(), MemSize,
Align);

SDValue StridedLoad = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs,
Ops, WideVecVT, MMO);
SDValue StridedLoad = DAG.getStridedLoadVP(
WideVecVT, DL, BaseLd->getChain(), BaseLd->getBasePtr(), Stride,
AllOneMask,
DAG.getConstant(N->getNumOperands(), DL, Subtarget.getXLenVT()), MMO);

for (SDValue Ld : N->ops())
DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), StridedLoad);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
define void @widen_2xv4i16(ptr %x, ptr %z) {
; CHECK-LABEL: widen_2xv4i16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, ma
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: vse16.v v8, (a1)
; CHECK-NEXT: vsetivli zero, 2, e64, m1, ta, ma
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may cause unaligned access error?

Copy link
Contributor Author

@lukel97 lukel97 Jul 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The combine checks for the alignment with TLI.isLegalStridedLoadStore(WideVecVT, Align), so I think this should be fine. This is what it was originally combining to anyway, the generic dag combine was just re-combining it to a smaller SEW

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the alignment of <4 x i16> is 64 bits, IIUC? Then this makes sense to me.

; CHECK-NEXT: vle64.v v8, (a0)
; CHECK-NEXT: vse64.v v8, (a1)
; CHECK-NEXT: ret
%a = load <4 x i16>, ptr %x
%b.gep = getelementptr i8, ptr %x, i64 8
Expand Down Expand Up @@ -52,9 +52,9 @@ define void @widen_3xv4i16(ptr %x, ptr %z) {
define void @widen_4xv4i16(ptr %x, ptr %z) {
; CHECK-LABEL: widen_4xv4i16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: vse16.v v8, (a1)
; CHECK-NEXT: vsetivli zero, 4, e64, m2, ta, ma
; CHECK-NEXT: vle64.v v8, (a0)
; CHECK-NEXT: vse64.v v8, (a1)
; CHECK-NEXT: ret
%a = load <4 x i16>, ptr %x
%b.gep = getelementptr i8, ptr %x, i64 8
Expand Down Expand Up @@ -90,9 +90,9 @@ define void @widen_4xv4i16_unaligned(ptr %x, ptr %z) {
;
; RV64-MISALIGN-LABEL: widen_4xv4i16_unaligned:
; RV64-MISALIGN: # %bb.0:
; RV64-MISALIGN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
; RV64-MISALIGN-NEXT: vle16.v v8, (a0)
; RV64-MISALIGN-NEXT: vse16.v v8, (a1)
; RV64-MISALIGN-NEXT: vsetivli zero, 4, e64, m2, ta, ma
; RV64-MISALIGN-NEXT: vle64.v v8, (a0)
; RV64-MISALIGN-NEXT: vse64.v v8, (a1)
; RV64-MISALIGN-NEXT: ret
%a = load <4 x i16>, ptr %x, align 1
%b.gep = getelementptr i8, ptr %x, i64 8
Expand Down
Loading