
Commit a7a5044

[RISCV] Lower non-power-of-2 vector to nearest power-of-2 vector length with VP intrinsic
This patch is still at an early stage, but I would like to publish it to demonstrate the feasibility of this approach. Its effect is mostly nullified by llvm#104689, but it can still show some improvement once more patterns are added, which will come in a later change.

The idea of this patch is to lower a non-power-of-2 vector to the nearest power-of-2 vector length using VP intrinsics, inserting vector.insert and vector.extract operations to convert from/to the original vector type.

Example:

```
define void @vls3i8(ptr align 8 %array) {
entry:
  %1 = load <3 x i8>, ptr %array, align 1
  %2 = add <3 x i8> %1, %1
  store <3 x i8> %2, ptr %array, align 1
  ret void
}
```

is lowered to:

```
define void @vls3i8(ptr align 8 %array) #0 {
entry:
  %0 = call <vscale x 4 x i8> @llvm.vp.load.nxv4i8.p0(ptr %array, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 3)
  %1 = call <3 x i8> @llvm.vector.extract.v3i8.nxv4i8(<vscale x 4 x i8> %0, i64 0)
  %2 = call <vscale x 4 x i8> @llvm.vector.insert.nxv4i8.v3i8(<vscale x 4 x i8> poison, <3 x i8> %1, i64 0)
  %3 = call <vscale x 4 x i8> @llvm.vector.insert.nxv4i8.v3i8(<vscale x 4 x i8> poison, <3 x i8> %1, i64 0)
  %4 = call <vscale x 4 x i8> @llvm.vp.add.nxv4i8(<vscale x 4 x i8> %2, <vscale x 4 x i8> %3, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 3)
  %5 = call <3 x i8> @llvm.vector.extract.v3i8.nxv4i8(<vscale x 4 x i8> %4, i64 0)
  %6 = call <vscale x 4 x i8> @llvm.vector.insert.nxv4i8.v3i8(<vscale x 4 x i8> poison, <3 x i8> %5, i64 0)
  call void @llvm.vp.store.nxv4i8.p0(<vscale x 4 x i8> %6, ptr %array, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 3)
  ret void
}
```
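For reference, here is a minimal standalone sketch of how the power-of-2 container size in the example above can be derived. The helper names are illustrative and not part of the patch; it mirrors the getContainerForFixedLengthVector logic added below and assumes a vscale_range minimum of 1, which is consistent with the <vscale x 4 x i8> container shown.

```
// Illustrative sketch only: derive the minimum element count of the scalable
// container used for a non-power-of-2 fixed-length vector.
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Smallest power of two strictly greater than N (same behavior as
// llvm::NextPowerOf2).
static uint64_t nextPowerOf2(uint64_t N) {
  uint64_t P = 1;
  while (P <= N)
    P <<= 1;
  return P;
}

// Minimum element count of the scalable container for a fixed-length vector,
// given the known minimum vscale of the enclosing function.
static uint64_t containerMinElts(uint64_t FixedNumElts, uint64_t MinVScale) {
  return std::max<uint64_t>(nextPowerOf2(FixedNumElts) / MinVScale, 1);
}

int main() {
  // <3 x i8> with an assumed vscale_range minimum of 1:
  // nextPowerOf2(3) = 4, 4 / 1 = 4  ->  container is <vscale x 4 x i8>.
  std::printf("container min elts: %llu\n",
              (unsigned long long)containerMinElts(3, 1));
  return 0;
}
```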
Parent commit: 1193f7d

File tree

10 files changed: +3591 −824 lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 12 additions & 0 deletions
@@ -25,6 +25,7 @@
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/TypeSize.h"
 #include "llvm/Support/raw_ostream.h"
@@ -5686,6 +5687,17 @@ SDValue DAGTypeLegalizer::WidenVecRes_EXTRACT_SUBVECTOR(SDNode *N) {
   unsigned WidenNumElts = WidenVT.getVectorMinNumElements();
   unsigned InNumElts = InVT.getVectorMinNumElements();
   unsigned VTNumElts = VT.getVectorMinNumElements();
+
+  if (InVT.isScalableVector()) {
+    unsigned EltSize = InVT.getScalarType().getFixedSizeInBits();
+    unsigned MinVScale =
+        getVScaleRange(&DAG.getMachineFunction().getFunction(), 64)
+            .getUnsignedMin()
+            .getZExtValue();
+    InNumElts = InNumElts * MinVScale;
+  }
+
   assert(IdxVal % VTNumElts == 0 &&
          "Expected Idx to be a multiple of subvector minimum vector length");
   if (IdxVal % WidenNumElts == 0 && IdxVal + WidenNumElts < InNumElts)
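A rough worked example of what the new scalable-vector branch changes, using hypothetical values (input type nxv4i8 and a vscale_range lower bound of 2; neither value comes from the patch itself): scaling the minimum element count by the known minimum vscale lets the widened extract take the cheap path when the index range is provably in bounds.

```
// Hypothetical numbers only, to illustrate the bound check above.
#include <cstdio>

int main() {
  unsigned InNumElts = 4;  // minimum element count of <vscale x 4 x i8>
  unsigned MinVScale = 2;  // lower bound from the function's vscale_range
  InNumElts = InNumElts * MinVScale; // at least 8 elements exist at runtime

  unsigned WidenNumElts = 4; // widened (power-of-2) result length
  unsigned IdxVal = 0;       // index of the extracted subvector
  bool TakeCheapPath =
      IdxVal % WidenNumElts == 0 && IdxVal + WidenNumElts < InNumElts;
  std::printf("cheap widened extract: %s\n", TakeCheapPath ? "yes" : "no");
  return 0;
}
```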

llvm/lib/Target/RISCV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ add_llvm_target(RISCVCodeGen
   RISCVTargetObjectFile.cpp
   RISCVTargetTransformInfo.cpp
   RISCVVectorPeephole.cpp
+  RISCVLegalizeNonPowerOf2Vector.cpp
   GISel/RISCVCallLowering.cpp
   GISel/RISCVInstructionSelector.cpp
   GISel/RISCVLegalizerInfo.cpp

llvm/lib/Target/RISCV/RISCV.h

Lines changed: 3 additions & 0 deletions
@@ -99,6 +99,9 @@ void initializeRISCVO0PreLegalizerCombinerPass(PassRegistry &);

 FunctionPass *createRISCVPreLegalizerCombiner();
 void initializeRISCVPreLegalizerCombinerPass(PassRegistry &);
+
+FunctionPass *createRISCVLegalizeNonPowerOf2Vector();
+void initializeRISCVLegalizeNonPowerOf2VectorPass(PassRegistry &);
 } // namespace llvm

 #endif
llvm/lib/Target/RISCV/RISCVLegalizeNonPowerOf2Vector.cpp (new file)

Lines changed: 199 additions & 0 deletions
#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/VectorBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"

#include <vector>

using namespace llvm;

#define DEBUG_TYPE "riscv-legalize-non-power-of-2-vector"
#define PASS_NAME "Legalize non-power-of-2 vector type"

namespace {
class RISCVLegalizeNonPowerOf2Vector : public FunctionPass {
  const RISCVSubtarget *ST;
  unsigned MinVScale;

public:
  static char ID;
  RISCVLegalizeNonPowerOf2Vector() : FunctionPass(ID) {}

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }

  bool runOnFunction(Function &F) override;
  StringRef getPassName() const override { return PASS_NAME; }

private:
  FixedVectorType *extracUsedFixedVectorType(const Instruction &I) const;

  bool isTargetType(FixedVectorType *VecTy) const;

  ScalableVectorType *
  getContainerForFixedLengthVector(FixedVectorType *FixedVecTy);
};
} // namespace

FixedVectorType *RISCVLegalizeNonPowerOf2Vector::extracUsedFixedVectorType(
    const Instruction &I) const {
  if (isa<FixedVectorType>(I.getType())) {
    return cast<FixedVectorType>(I.getType());
  } else if (isa<StoreInst>(I) &&
             isa<FixedVectorType>(
                 cast<StoreInst>(&I)->getValueOperand()->getType())) {
    return cast<FixedVectorType>(
        cast<StoreInst>(&I)->getValueOperand()->getType());
  }
  return nullptr;
}

ScalableVectorType *
RISCVLegalizeNonPowerOf2Vector::getContainerForFixedLengthVector(
    FixedVectorType *FixedVecTy) {
  // TODO: Consider vscale_range to pick a better/smaller type.
  uint64_t NumElts = std::max<uint64_t>(
      NextPowerOf2(FixedVecTy->getNumElements()) / MinVScale, 1);

  Type *ElementType = FixedVecTy->getElementType();

  if (ElementType->isIntegerTy(1))
    NumElts = std::max(NumElts, 8UL);

  return ScalableVectorType::get(ElementType, NumElts);
}

bool RISCVLegalizeNonPowerOf2Vector::isTargetType(
    FixedVectorType *VecTy) const {
  if (isPowerOf2_32(VecTy->getNumElements()))
    return false;

  Type *EltTy = VecTy->getElementType();

  if (EltTy->isIntegerTy(1))
    return false;

  if (EltTy->isIntegerTy(64))
    return ST->hasVInstructionsI64();
  else if (EltTy->isFloatTy())
    return ST->hasVInstructionsF32();
  else if (EltTy->isDoubleTy())
    return ST->hasVInstructionsF64();
  else if (EltTy->isHalfTy())
    return ST->hasVInstructionsF16Minimal();
  else if (EltTy->isBFloatTy())
    return ST->hasVInstructionsBF16Minimal();

  return (EltTy->isIntegerTy(1) || EltTy->isIntegerTy(8) ||
          EltTy->isIntegerTy(16) || EltTy->isIntegerTy(32));
}

bool RISCVLegalizeNonPowerOf2Vector::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);

  if (!ST->hasVInstructions())
    return false;

  auto Attr = F.getFnAttribute(Attribute::VScaleRange);
  if (Attr.isValid()) {
    MinVScale = Attr.getVScaleRangeMin();
  } else {
    unsigned MinVLen = ST->getRealMinVLen();
    if (MinVLen < RISCV::RVVBitsPerBlock)
      return false;
    MinVScale = MinVLen / RISCV::RVVBitsPerBlock;
    AttrBuilder AB(F.getContext());
    AB.addVScaleRangeAttr(MinVScale, std::optional<unsigned>());
    F.addFnAttr(AB.getAttribute(Attribute::VScaleRange));
  }

  bool Modified = false;
  std::vector<Instruction *> ToBeRemoved;
  for (auto &BB : F) {
    for (auto &I : make_range(BB.rbegin(), BB.rend())) {
      if (auto VecTy = extracUsedFixedVectorType(I)) {
        if (!isTargetType(VecTy)) {
          continue;
        }

        Value *I64Zero = ConstantInt::get(Type::getInt64Ty(F.getContext()), 0);

        // Replace fixed length vector with scalable vector
        IRBuilder<> Builder(&I);
        VectorBuilder VecBuilder(Builder);
        VecBuilder.setStaticVL(VecTy->getNumElements());
        VectorType *NewVecTy = getContainerForFixedLengthVector(VecTy);
        VecBuilder.setMask(Builder.CreateVectorSplat(
            NewVecTy->getElementCount(), Builder.getTrue()));

        if (auto *BinOp = dyn_cast<BinaryOperator>(&I)) {
          Value *Op1 = BinOp->getOperand(0);
          Value *Op2 = BinOp->getOperand(1);
          Value *NewOp1 = Builder.CreateInsertVector(
              NewVecTy, PoisonValue::get(NewVecTy), Op1, I64Zero);
          Value *NewOp2 = Builder.CreateInsertVector(
              NewVecTy, PoisonValue::get(NewVecTy), Op2, I64Zero);
          Value *NewBinOp = VecBuilder.createVectorInstruction(
              BinOp->getOpcode(), NewVecTy, {NewOp1, NewOp2});
          Value *FinalResult =
              Builder.CreateExtractVector(VecTy, NewBinOp, I64Zero);
          BinOp->replaceAllUsesWith(FinalResult);
          ToBeRemoved.push_back(BinOp);
          Modified = true;
        } else if (auto *StoreOp = dyn_cast<StoreInst>(&I)) {
          Value *Val = StoreOp->getOperand(0);
          Value *Addr = StoreOp->getOperand(1);
          Value *NewVal = Builder.CreateInsertVector(
              NewVecTy, PoisonValue::get(NewVecTy), Val, I64Zero);
          Value *NewStoreOp = VecBuilder.createVectorInstruction(
              StoreOp->getOpcode(), NewVecTy, {NewVal, Addr});
          StoreOp->replaceAllUsesWith(NewStoreOp);
          ToBeRemoved.push_back(StoreOp);
        } else if (auto *LoadOp = dyn_cast<LoadInst>(&I)) {
          Value *Addr = LoadOp->getOperand(0);
          Value *NewLoadOp = VecBuilder.createVectorInstruction(
              LoadOp->getOpcode(), NewVecTy, {Addr});
          Value *FinalResult =
              Builder.CreateExtractVector(VecTy, NewLoadOp, I64Zero);
          LoadOp->replaceAllUsesWith(FinalResult);
          ToBeRemoved.push_back(LoadOp);
        }
      }
    }
  }
  for_each(ToBeRemoved.begin(), ToBeRemoved.end(),
           [](Instruction *I) { I->eraseFromParent(); });
  return Modified;
}

char RISCVLegalizeNonPowerOf2Vector::ID = 0;

INITIALIZE_PASS_BEGIN(RISCVLegalizeNonPowerOf2Vector, DEBUG_TYPE, PASS_NAME,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(RISCVLegalizeNonPowerOf2Vector, DEBUG_TYPE, PASS_NAME,
                    false, false)

FunctionPass *llvm::createRISCVLegalizeNonPowerOf2Vector() {
  return new RISCVLegalizeNonPowerOf2Vector();
}

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

Lines changed: 2 additions & 0 deletions
@@ -128,6 +128,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() {
   initializeRISCVDAGToDAGISelLegacyPass(*PR);
   initializeRISCVMoveMergePass(*PR);
   initializeRISCVPushPopOptPass(*PR);
+  initializeRISCVLegalizeNonPowerOf2VectorPass(*PR);
 }

 static StringRef computeDataLayout(const Triple &TT,
@@ -452,6 +453,7 @@ bool RISCVPassConfig::addPreISel() {
 void RISCVPassConfig::addCodeGenPrepare() {
   if (getOptLevel() != CodeGenOptLevel::None)
     addPass(createTypePromotionLegacyPass());
+  addPass(createRISCVLegalizeNonPowerOf2Vector());
   TargetPassConfig::addCodeGenPrepare();
 }

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-abs.ll

Lines changed: 18 additions & 0 deletions
@@ -39,7 +39,25 @@ define void @abs_v6i16(ptr %x) {
 ; CHECK: # %bb.0:
 ; CHECK-NEXT: vsetivli zero, 6, e16, m1, ta, ma
 ; CHECK-NEXT: vle16.v v8, (a0)
+; CHECK-NEXT: vslidedown.vi v9, v8, 1
+; CHECK-NEXT: vmv.x.s a1, v9
+; CHECK-NEXT: vmv.x.s a2, v8
 ; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, ma
+; CHECK-NEXT: vmv.v.x v9, a2
+; CHECK-NEXT: vslide1down.vx v9, v9, a1
+; CHECK-NEXT: vslidedown.vi v10, v8, 2
+; CHECK-NEXT: vmv.x.s a1, v10
+; CHECK-NEXT: vslide1down.vx v9, v9, a1
+; CHECK-NEXT: vslidedown.vi v10, v8, 3
+; CHECK-NEXT: vmv.x.s a1, v10
+; CHECK-NEXT: vslide1down.vx v9, v9, a1
+; CHECK-NEXT: vslidedown.vi v10, v8, 4
+; CHECK-NEXT: vmv.x.s a1, v10
+; CHECK-NEXT: vslide1down.vx v9, v9, a1
+; CHECK-NEXT: vslidedown.vi v8, v8, 5
+; CHECK-NEXT: vmv.x.s a1, v8
+; CHECK-NEXT: vslide1down.vx v8, v9, a1
+; CHECK-NEXT: vslidedown.vi v8, v8, 2
 ; CHECK-NEXT: vrsub.vi v9, v8, 0
 ; CHECK-NEXT: vsetivli zero, 6, e16, m1, ta, ma
 ; CHECK-NEXT: vmax.vv v8, v8, v9

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-extract.ll

Lines changed: 52 additions & 4 deletions
@@ -220,7 +220,18 @@ define i64 @extractelt_v3i64(ptr %x) nounwind {
 ; RV32: # %bb.0:
 ; RV32-NEXT: vsetivli zero, 3, e64, m2, ta, ma
 ; RV32-NEXT: vle64.v v8, (a0)
-; RV32-NEXT: vsetivli zero, 1, e32, m2, ta, ma
+; RV32-NEXT: vslidedown.vi v8, v8, 2
+; RV32-NEXT: vmv.x.s a0, v8
+; RV32-NEXT: vmv.s.x v10, a0
+; RV32-NEXT: li a0, 32
+; RV32-NEXT: vsetivli zero, 1, e64, m2, ta, ma
+; RV32-NEXT: vsrl.vx v8, v8, a0
+; RV32-NEXT: vmv.x.s a0, v8
+; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV32-NEXT: vmv.v.x v8, a0
+; RV32-NEXT: vsetivli zero, 5, e32, m2, tu, ma
+; RV32-NEXT: vslideup.vi v8, v10, 4
+; RV32-NEXT: vsetvli zero, zero, e32, m2, ta, ma
 ; RV32-NEXT: vslidedown.vi v10, v8, 4
 ; RV32-NEXT: vmv.x.s a0, v10
 ; RV32-NEXT: vslidedown.vi v8, v8, 5
@@ -567,10 +578,37 @@ define i64 @extractelt_v3i64_idx(ptr %x, i32 zeroext %idx) nounwind {
 ; RV32: # %bb.0:
 ; RV32-NEXT: vsetivli zero, 3, e64, m2, ta, ma
 ; RV32-NEXT: vle64.v v8, (a0)
-; RV32-NEXT: vsetivli zero, 4, e64, m2, ta, ma
 ; RV32-NEXT: vadd.vv v8, v8, v8
+; RV32-NEXT: li a0, 32
+; RV32-NEXT: vsetivli zero, 1, e64, m2, ta, ma
+; RV32-NEXT: vsrl.vx v10, v8, a0
+; RV32-NEXT: vmv.x.s a2, v10
+; RV32-NEXT: vmv.x.s a3, v8
+; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV32-NEXT: vmv.v.x v10, a3
+; RV32-NEXT: vslide1down.vx v10, v10, a2
+; RV32-NEXT: vsetivli zero, 1, e64, m2, ta, ma
+; RV32-NEXT: vslidedown.vi v12, v8, 1
+; RV32-NEXT: vmv.x.s a2, v12
+; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV32-NEXT: vslide1down.vx v10, v10, a2
+; RV32-NEXT: vsetivli zero, 1, e64, m2, ta, ma
+; RV32-NEXT: vsrl.vx v12, v12, a0
+; RV32-NEXT: vmv.x.s a2, v12
+; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV32-NEXT: vslide1down.vx v10, v10, a2
+; RV32-NEXT: vsetivli zero, 1, e64, m2, ta, ma
+; RV32-NEXT: vslidedown.vi v8, v8, 2
+; RV32-NEXT: vmv.x.s a2, v8
+; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV32-NEXT: vslide1down.vx v10, v10, a2
+; RV32-NEXT: vsetivli zero, 1, e64, m2, ta, ma
+; RV32-NEXT: vsrl.vx v8, v8, a0
+; RV32-NEXT: vmv.x.s a0, v8
+; RV32-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; RV32-NEXT: vslide1down.vx v8, v10, a0
+; RV32-NEXT: vslidedown.vi v8, v8, 2
 ; RV32-NEXT: add a1, a1, a1
-; RV32-NEXT: vsetivli zero, 1, e32, m2, ta, ma
 ; RV32-NEXT: vslidedown.vx v10, v8, a1
 ; RV32-NEXT: vmv.x.s a0, v10
 ; RV32-NEXT: addi a1, a1, 1
@@ -582,8 +620,18 @@ define i64 @extractelt_v3i64_idx(ptr %x, i32 zeroext %idx) nounwind {
 ; RV64: # %bb.0:
 ; RV64-NEXT: vsetivli zero, 3, e64, m2, ta, ma
 ; RV64-NEXT: vle64.v v8, (a0)
-; RV64-NEXT: vsetivli zero, 4, e64, m2, ta, ma
 ; RV64-NEXT: vadd.vv v8, v8, v8
+; RV64-NEXT: vsetivli zero, 1, e64, m1, ta, ma
+; RV64-NEXT: vslidedown.vi v10, v8, 1
+; RV64-NEXT: vmv.x.s a0, v10
+; RV64-NEXT: vmv.x.s a2, v8
+; RV64-NEXT: vsetivli zero, 4, e64, m2, ta, ma
+; RV64-NEXT: vmv.v.x v10, a2
+; RV64-NEXT: vslide1down.vx v10, v10, a0
+; RV64-NEXT: vslidedown.vi v8, v8, 2
+; RV64-NEXT: vmv.x.s a0, v8
+; RV64-NEXT: vslide1down.vx v8, v10, a0
+; RV64-NEXT: vslidedown.vi v8, v8, 1
 ; RV64-NEXT: vslidedown.vx v8, v8, a1
 ; RV64-NEXT: vmv.x.s a0, v8
 ; RV64-NEXT: ret
