
Commit cda245a

[RISCV] Expand vp.stride.load to splat of a scalar load. (#98140)
This is a patch similar to a214c52, but for vp.strided.load: some targets prefer the pattern (vmv.v.x (load)) over a vlse with zero stride. It is the IR-level version of #97798.
1 parent 9324c95 commit cda245a
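
At the IR level the rewrite looks roughly like the following: a minimal sketch for riscv64, modeled on the nxv1i8 test below. It assumes a legal scalable type and an EVL that is provably nonzero; the mangled intrinsic name illustrates what IRBuilder produces and is not copied from the commit.

  ; Before: zero-stride VP load with an all-ones mask.
  %v = call <vscale x 1 x i8> @llvm.experimental.vp.strided.load.nxv1i8.p0.i8(ptr %ptr, i8 0, <vscale x 1 x i1> splat (i1 true), i32 %evl)

  ; After: one scalar load, splatted at the same VL
  ; (%evl assumed nonzero, zero-extended to XLen on riscv64).
  %val = load i8, ptr %ptr
  %vl = zext i32 %evl to i64
  %v.splat = call <vscale x 1 x i8> @llvm.riscv.vmv.v.x.nxv1i8.i64(<vscale x 1 x i8> poison, i8 %val, i64 %vl)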

File tree

3 files changed: +140 -6 lines changed


llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp

Lines changed: 50 additions & 0 deletions
@@ -18,9 +18,11 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsRISCV.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -35,6 +37,7 @@ namespace {
 class RISCVCodeGenPrepare : public FunctionPass,
                             public InstVisitor<RISCVCodeGenPrepare, bool> {
   const DataLayout *DL;
+  const DominatorTree *DT;
   const RISCVSubtarget *ST;
 
 public:
@@ -48,12 +51,14 @@ class RISCVCodeGenPrepare : public FunctionPass,
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesCFG();
+    AU.addRequired<DominatorTreeWrapperPass>();
     AU.addRequired<TargetPassConfig>();
   }
 
   bool visitInstruction(Instruction &I) { return false; }
   bool visitAnd(BinaryOperator &BO);
   bool visitIntrinsicInst(IntrinsicInst &I);
+  bool expandVPStrideLoad(IntrinsicInst &I);
 };
 
 } // end anonymous namespace
@@ -128,6 +133,9 @@ bool RISCVCodeGenPrepare::visitAnd(BinaryOperator &BO) {
 // Which eliminates the scalar -> vector -> scalar crossing during instruction
 // selection.
 bool RISCVCodeGenPrepare::visitIntrinsicInst(IntrinsicInst &I) {
+  if (expandVPStrideLoad(I))
+    return true;
+
   if (I.getIntrinsicID() != Intrinsic::vector_reduce_fadd)
     return false;
 
@@ -155,6 +163,47 @@ bool RISCVCodeGenPrepare::visitIntrinsicInst(IntrinsicInst &I) {
   return true;
 }
 
+bool RISCVCodeGenPrepare::expandVPStrideLoad(IntrinsicInst &II) {
+  if (ST->hasOptimizedZeroStrideLoad())
+    return false;
+
+  Value *BasePtr, *VL;
+  using namespace PatternMatch;
+  if (!match(&II, m_Intrinsic<Intrinsic::experimental_vp_strided_load>(
+                      m_Value(BasePtr), m_Zero(), m_AllOnes(), m_Value(VL))))
+    return false;
+
+  if (!isKnownNonZero(VL, {*DL, DT, nullptr, &II}))
+    return false;
+
+  auto *VTy = cast<VectorType>(II.getType());
+
+  IRBuilder<> Builder(&II);
+
+  // Extend VL from i32 to XLen if needed.
+  if (ST->is64Bit())
+    VL = Builder.CreateZExt(VL, Builder.getInt64Ty());
+
+  Type *STy = VTy->getElementType();
+  Value *Val = Builder.CreateLoad(STy, BasePtr);
+  const auto &TLI = *ST->getTargetLowering();
+  Value *Res;
+
+  // TODO: Also support fixed/illegal vector types to splat with evl = vl.
+  if (isa<ScalableVectorType>(VTy) && TLI.isTypeLegal(EVT::getEVT(VTy))) {
+    unsigned VMVOp = STy->isFloatingPointTy() ? Intrinsic::riscv_vfmv_v_f
+                                              : Intrinsic::riscv_vmv_v_x;
+    Res = Builder.CreateIntrinsic(VMVOp, {VTy, VL->getType()},
+                                  {PoisonValue::get(VTy), Val, VL});
+  } else {
+    Res = Builder.CreateVectorSplat(VTy->getElementCount(), Val);
+  }
+
+  II.replaceAllUsesWith(Res);
+  II.eraseFromParent();
+  return true;
+}
+
 bool RISCVCodeGenPrepare::runOnFunction(Function &F) {
   if (skipFunction(F))
     return false;
@@ -164,6 +213,7 @@ bool RISCVCodeGenPrepare::runOnFunction(Function &F) {
   ST = &TM.getSubtarget<RISCVSubtarget>(F);
 
   DL = &F.getDataLayout();
+  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
 
   bool MadeChange = false;
   for (auto &BB : F)
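
A note on the isKnownNonZero(VL) guard above: a strided VP load with evl == 0 reads no memory at all, so turning it into an unconditional scalar load is only safe when the EVL is provably nonzero; otherwise the expansion could fault on a pointer the original code never dereferences. A sketch of a case the pass must skip (the function name is hypothetical; %evl is an arbitrary argument, so it is not known nonzero):

  ; Not expanded: if %evl == 0 at runtime the VP load touches no memory,
  ; but a plain "load i8, ptr %ptr" would always dereference %ptr.
  define <vscale x 1 x i8> @no_expand(ptr %ptr, i32 %evl) {
    %v = call <vscale x 1 x i8> @llvm.experimental.vp.strided.load.nxv1i8.p0.i8(ptr %ptr, i8 0, <vscale x 1 x i1> splat (i1 true), i32 %evl)
    ret <vscale x 1 x i8> %v
  }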

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-vpload.ll

Lines changed: 46 additions & 4 deletions
@@ -1,10 +1,16 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+v,+zvfh \
-; RUN:   -verify-machineinstrs < %s \
-; RUN:   | FileCheck %s --check-prefixes=CHECK,CHECK-RV32
+; RUN:   -verify-machineinstrs < %s | FileCheck %s \
+; RUN:   -check-prefixes=CHECK,CHECK-RV32,CHECK-OPT
 ; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+v,+zvfh \
-; RUN:   -verify-machineinstrs < %s \
-; RUN:   | FileCheck %s --check-prefixes=CHECK,CHECK-RV64
+; RUN:   -verify-machineinstrs < %s | FileCheck %s \
+; RUN:   -check-prefixes=CHECK,CHECK-RV64,CHECK-OPT
+; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+v,+zvfh,+no-optimized-zero-stride-load \
+; RUN:   -verify-machineinstrs < %s | FileCheck %s \
+; RUN:   -check-prefixes=CHECK,CHECK-RV32,CHECK-NOOPT
+; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+v,+zvfh,+no-optimized-zero-stride-load \
+; RUN:   -verify-machineinstrs < %s | FileCheck %s \
+; RUN:   -check-prefixes=CHECK,CHECK-RV64,CHECK-NOOPT
 
 declare <2 x i8> @llvm.experimental.vp.strided.load.v2i8.p0.i8(ptr, i8, <2 x i1>, i32)
 
@@ -626,3 +632,39 @@ define <33 x double> @strided_load_v33f64(ptr %ptr, i64 %stride, <33 x i1> %mask
 }
 
 declare <33 x double> @llvm.experimental.vp.strided.load.v33f64.p0.i64(ptr, i64, <33 x i1>, i32)
+
+; Test unmasked integer zero strided
+define <4 x i8> @zero_strided_unmasked_vpload_4i8_i8(ptr %ptr) {
+; CHECK-OPT-LABEL: zero_strided_unmasked_vpload_4i8_i8:
+; CHECK-OPT:       # %bb.0:
+; CHECK-OPT-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-OPT-NEXT:    vlse8.v v8, (a0), zero
+; CHECK-OPT-NEXT:    ret
+;
+; CHECK-NOOPT-LABEL: zero_strided_unmasked_vpload_4i8_i8:
+; CHECK-NOOPT:       # %bb.0:
+; CHECK-NOOPT-NEXT:    lbu a0, 0(a0)
+; CHECK-NOOPT-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-NOOPT-NEXT:    vmv.v.x v8, a0
+; CHECK-NOOPT-NEXT:    ret
+  %load = call <4 x i8> @llvm.experimental.vp.strided.load.4i8.p0.i8(ptr %ptr, i8 0, <4 x i1> splat (i1 true), i32 4)
+  ret <4 x i8> %load
+}
+
+; Test unmasked float zero strided
+define <4 x half> @zero_strided_unmasked_vpload_4f16(ptr %ptr) {
+; CHECK-OPT-LABEL: zero_strided_unmasked_vpload_4f16:
+; CHECK-OPT:       # %bb.0:
+; CHECK-OPT-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-OPT-NEXT:    vlse16.v v8, (a0), zero
+; CHECK-OPT-NEXT:    ret
+;
+; CHECK-NOOPT-LABEL: zero_strided_unmasked_vpload_4f16:
+; CHECK-NOOPT:       # %bb.0:
+; CHECK-NOOPT-NEXT:    flh fa5, 0(a0)
+; CHECK-NOOPT-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NOOPT-NEXT:    vfmv.v.f v8, fa5
+; CHECK-NOOPT-NEXT:    ret
+  %load = call <4 x half> @llvm.experimental.vp.strided.load.4f16.p0.i32(ptr %ptr, i32 0, <4 x i1> splat (i1 true), i32 4)
+  ret <4 x half> %load
+}
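
Since <4 x i8> is a fixed (non-scalable) vector type, these tests exercise the CreateVectorSplat branch of expandVPStrideLoad rather than the riscv.vmv.v.x intrinsic path. A sketch of the IR that branch produces (the standard insertelement + shufflevector splat form), which instruction selection then lowers to the lbu + vmv.v.x sequence in the CHECK-NOOPT lines:

  %val = load i8, ptr %ptr
  %head = insertelement <4 x i8> poison, i8 %val, i64 0
  %splat = shufflevector <4 x i8> %head, <4 x i8> poison, <4 x i32> zeroinitializer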

llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll

Lines changed: 44 additions & 2 deletions
@@ -1,10 +1,16 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+v,+zvfh \
 ; RUN:   -verify-machineinstrs < %s | FileCheck %s \
-; RUN:   -check-prefixes=CHECK,CHECK-RV32
+; RUN:   -check-prefixes=CHECK,CHECK-RV32,CHECK-OPT
 ; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+v,+zvfh \
 ; RUN:   -verify-machineinstrs < %s | FileCheck %s \
-; RUN:   -check-prefixes=CHECK,CHECK-RV64
+; RUN:   -check-prefixes=CHECK,CHECK-RV64,CHECK-OPT
+; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+v,+zvfh,+no-optimized-zero-stride-load \
+; RUN:   -verify-machineinstrs < %s | FileCheck %s \
+; RUN:   -check-prefixes=CHECK,CHECK-RV32,CHECK-NOOPT
+; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+v,+zvfh,+no-optimized-zero-stride-load \
+; RUN:   -verify-machineinstrs < %s | FileCheck %s \
+; RUN:   -check-prefixes=CHECK,CHECK-RV64,CHECK-NOOPT
 
 declare <vscale x 1 x i8> @llvm.experimental.vp.strided.load.nxv1i8.p0.i8(ptr, i8, <vscale x 1 x i1>, i32)
 
@@ -780,3 +786,39 @@ define <vscale x 16 x double> @strided_load_nxv17f64(ptr %ptr, i64 %stride, <vsc
 declare <vscale x 17 x double> @llvm.experimental.vp.strided.load.nxv17f64.p0.i64(ptr, i64, <vscale x 17 x i1>, i32)
 declare <vscale x 1 x double> @llvm.experimental.vector.extract.nxv1f64(<vscale x 17 x double> %vec, i64 %idx)
 declare <vscale x 16 x double> @llvm.experimental.vector.extract.nxv16f64(<vscale x 17 x double> %vec, i64 %idx)
+
+; Test unmasked integer zero strided
+define <vscale x 1 x i8> @zero_strided_unmasked_vpload_nxv1i8_i8(ptr %ptr) {
+; CHECK-OPT-LABEL: zero_strided_unmasked_vpload_nxv1i8_i8:
+; CHECK-OPT:       # %bb.0:
+; CHECK-OPT-NEXT:    vsetivli zero, 4, e8, mf8, ta, ma
+; CHECK-OPT-NEXT:    vlse8.v v8, (a0), zero
+; CHECK-OPT-NEXT:    ret
+;
+; CHECK-NOOPT-LABEL: zero_strided_unmasked_vpload_nxv1i8_i8:
+; CHECK-NOOPT:       # %bb.0:
+; CHECK-NOOPT-NEXT:    lbu a0, 0(a0)
+; CHECK-NOOPT-NEXT:    vsetivli zero, 4, e8, mf8, ta, ma
+; CHECK-NOOPT-NEXT:    vmv.v.x v8, a0
+; CHECK-NOOPT-NEXT:    ret
+  %load = call <vscale x 1 x i8> @llvm.experimental.vp.strided.load.nxv1i8.p0.i8(ptr %ptr, i8 0, <vscale x 1 x i1> splat (i1 true), i32 4)
+  ret <vscale x 1 x i8> %load
+}
+
+; Test unmasked float zero strided
+define <vscale x 1 x half> @zero_strided_unmasked_vpload_nxv1f16(ptr %ptr) {
+; CHECK-OPT-LABEL: zero_strided_unmasked_vpload_nxv1f16:
+; CHECK-OPT:       # %bb.0:
+; CHECK-OPT-NEXT:    vsetivli zero, 4, e16, mf4, ta, ma
+; CHECK-OPT-NEXT:    vlse16.v v8, (a0), zero
+; CHECK-OPT-NEXT:    ret
+;
+; CHECK-NOOPT-LABEL: zero_strided_unmasked_vpload_nxv1f16:
+; CHECK-NOOPT:       # %bb.0:
+; CHECK-NOOPT-NEXT:    flh fa5, 0(a0)
+; CHECK-NOOPT-NEXT:    vsetivli zero, 4, e16, mf4, ta, ma
+; CHECK-NOOPT-NEXT:    vfmv.v.f v8, fa5
+; CHECK-NOOPT-NEXT:    ret
+  %load = call <vscale x 1 x half> @llvm.experimental.vp.strided.load.nxv1f16.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 4)
+  ret <vscale x 1 x half> %load
+}
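
For the scalable nxv1f16 case the element type is floating point, so expandVPStrideLoad picks riscv_vfmv_v_f instead of riscv_vmv_v_x, which is why the CHECK-NOOPT output above selects flh + vfmv.v.f. A sketch of the intermediate IR on riscv64 (the constant EVL 4 zero-extended to i64; the mangled intrinsic name illustrates what IRBuilder produces and is not copied from the commit):

  %val = load half, ptr %ptr
  %v = call <vscale x 1 x half> @llvm.riscv.vfmv.v.f.nxv1f16.i64(<vscale x 1 x half> poison, half %val, i64 4)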
