Skip to content

Commit 1215a7b

Browse files
vporposmallp-o-p
authored and committed
[SandboxVec][Legality] Per opcode checks (llvm#114145)
This patch adds more opcode-specific legality checks.
1 parent 254a3ae commit 1215a7b

File tree

7 files changed

+580
-10
lines changed

7 files changed

+580
-10
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
1414

1515
#include "llvm/ADT/ArrayRef.h"
16+
#include "llvm/Analysis/ScalarEvolution.h"
17+
#include "llvm/IR/DataLayout.h"
1618
#include "llvm/Support/Casting.h"
1719
#include "llvm/Support/raw_ostream.h"
1820

@@ -33,6 +35,9 @@ enum class ResultReason {
3335
DiffTypes,
3436
DiffMathFlags,
3537
DiffWrapFlags,
38+
NotConsecutive,
39+
Unimplemented,
40+
Infeasible,
3641
};
3742

3843
#ifndef NDEBUG
@@ -59,6 +64,12 @@ struct ToStr {
5964
return "DiffMathFlags";
6065
case ResultReason::DiffWrapFlags:
6166
return "DiffWrapFlags";
67+
case ResultReason::NotConsecutive:
68+
return "NotConsecutive";
69+
case ResultReason::Unimplemented:
70+
return "Unimplemented";
71+
case ResultReason::Infeasible:
72+
return "Infeasible";
6273
}
6374
llvm_unreachable("Unknown ResultReason enum");
6475
}
@@ -142,8 +153,12 @@ class LegalityAnalysis {
142153
std::optional<ResultReason>
143154
notVectorizableBasedOnOpcodesAndTypes(ArrayRef<Value *> Bndl);
144155

156+
ScalarEvolution &SE;
157+
const DataLayout &DL;
158+
145159
public:
146-
LegalityAnalysis() = default;
160+
LegalityAnalysis(ScalarEvolution &SE, const DataLayout &DL)
161+
: SE(SE), DL(DL) {}
147162
/// A LegalityResult factory.
148163
template <typename ResultT, typename... ArgsT>
149164
ResultT &createLegalityResult(ArgsT... Args) {

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace llvm::sandboxir {
2424

2525
class BottomUpVec final : public FunctionPass {
2626
bool Change = false;
27-
LegalityAnalysis Legality;
27+
std::unique_ptr<LegalityAnalysis> Legality;
2828
void vectorizeRec(ArrayRef<Value *> Bndl);
2929
void tryVectorize(ArrayRef<Value *> Seeds);
3030

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
1313
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
1414

15+
#include "llvm/Analysis/ScalarEvolution.h"
16+
#include "llvm/IR/DataLayout.h"
1517
#include "llvm/SandboxIR/Type.h"
18+
#include "llvm/SandboxIR/Utils.h"
1619

1720
namespace llvm::sandboxir {
1821

@@ -29,6 +32,40 @@ class VecUtils {
2932
static Type *getElementType(Type *Ty) {
3033
return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
3134
}
35+
36+
/// \Returns true if \p I1 and \p I2 are load/stores accessing consecutive
37+
/// memory addresses.
38+
template <typename LoadOrStoreT>
39+
static bool areConsecutive(LoadOrStoreT *I1, LoadOrStoreT *I2,
40+
ScalarEvolution &SE, const DataLayout &DL) {
41+
static_assert(std::is_same<LoadOrStoreT, LoadInst>::value ||
42+
std::is_same<LoadOrStoreT, StoreInst>::value,
43+
"Expected Load or Store!");
44+
auto Diff = Utils::getPointerDiffInBytes(I1, I2, SE);
45+
if (!Diff)
46+
return false;
47+
int ElmBytes = Utils::getNumBits(I1) / 8;
48+
return *Diff == ElmBytes;
49+
}
50+
51+
template <typename LoadOrStoreT>
52+
static bool areConsecutive(ArrayRef<Value *> &Bndl, ScalarEvolution &SE,
53+
const DataLayout &DL) {
54+
static_assert(std::is_same<LoadOrStoreT, LoadInst>::value ||
55+
std::is_same<LoadOrStoreT, StoreInst>::value,
56+
"Expected Load or Store!");
57+
assert(isa<LoadOrStoreT>(Bndl[0]) && "Expected Load or Store!");
58+
auto *LastLS = cast<LoadOrStoreT>(Bndl[0]);
59+
for (Value *V : drop_begin(Bndl)) {
60+
assert(isa<LoadOrStoreT>(V) &&
61+
"Unimplemented: we only support StoreInst!");
62+
auto *LS = cast<LoadOrStoreT>(V);
63+
if (!VecUtils::areConsecutive(LastLS, LS, SE, DL))
64+
return false;
65+
LastLS = LS;
66+
}
67+
return true;
68+
}
3269
};
3370

3471
} // namespace llvm::sandboxir

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,109 @@ LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
7070
}
7171
}
7272

73-
// TODO: Missing checks
73+
// Now we need to do further checks for specific opcodes.
74+
switch (Opcode) {
75+
case Instruction::Opcode::ZExt:
76+
case Instruction::Opcode::SExt:
77+
case Instruction::Opcode::FPToUI:
78+
case Instruction::Opcode::FPToSI:
79+
case Instruction::Opcode::FPExt:
80+
case Instruction::Opcode::PtrToInt:
81+
case Instruction::Opcode::IntToPtr:
82+
case Instruction::Opcode::SIToFP:
83+
case Instruction::Opcode::UIToFP:
84+
case Instruction::Opcode::Trunc:
85+
case Instruction::Opcode::FPTrunc:
86+
case Instruction::Opcode::BitCast: {
87+
// We have already checked that they are of the same opcode.
88+
assert(all_of(Bndl,
89+
[Opcode](Value *V) {
90+
return cast<Instruction>(V)->getOpcode() == Opcode;
91+
}) &&
92+
"Different opcodes, should have early returned!");
93+
// But for these opcodes we should also check the operand type.
94+
Type *FromTy0 = Utils::getExpectedType(I0->getOperand(0));
95+
if (any_of(drop_begin(Bndl), [FromTy0](Value *V) {
96+
return Utils::getExpectedType(cast<User>(V)->getOperand(0)) !=
97+
FromTy0;
98+
}))
99+
return ResultReason::DiffTypes;
100+
return std::nullopt;
101+
}
102+
case Instruction::Opcode::FCmp:
103+
case Instruction::Opcode::ICmp: {
104+
// We need the same predicate.
105+
auto Pred0 = cast<CmpInst>(I0)->getPredicate();
106+
bool Same = all_of(Bndl, [Pred0](Value *V) {
107+
return cast<CmpInst>(V)->getPredicate() == Pred0;
108+
});
109+
if (Same)
110+
return std::nullopt;
111+
return ResultReason::DiffOpcodes;
112+
}
113+
case Instruction::Opcode::Select:
114+
case Instruction::Opcode::FNeg:
115+
case Instruction::Opcode::Add:
116+
case Instruction::Opcode::FAdd:
117+
case Instruction::Opcode::Sub:
118+
case Instruction::Opcode::FSub:
119+
case Instruction::Opcode::Mul:
120+
case Instruction::Opcode::FMul:
121+
case Instruction::Opcode::FRem:
122+
case Instruction::Opcode::UDiv:
123+
case Instruction::Opcode::SDiv:
124+
case Instruction::Opcode::FDiv:
125+
case Instruction::Opcode::URem:
126+
case Instruction::Opcode::SRem:
127+
case Instruction::Opcode::Shl:
128+
case Instruction::Opcode::LShr:
129+
case Instruction::Opcode::AShr:
130+
case Instruction::Opcode::And:
131+
case Instruction::Opcode::Or:
132+
case Instruction::Opcode::Xor:
133+
return std::nullopt;
134+
case Instruction::Opcode::Load:
135+
if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL))
136+
return std::nullopt;
137+
return ResultReason::NotConsecutive;
138+
case Instruction::Opcode::Store:
139+
if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL))
140+
return std::nullopt;
141+
return ResultReason::NotConsecutive;
142+
case Instruction::Opcode::PHI:
143+
return ResultReason::Unimplemented;
144+
case Instruction::Opcode::Opaque:
145+
return ResultReason::Unimplemented;
146+
case Instruction::Opcode::Br:
147+
case Instruction::Opcode::Ret:
148+
case Instruction::Opcode::AddrSpaceCast:
149+
case Instruction::Opcode::InsertElement:
150+
case Instruction::Opcode::InsertValue:
151+
case Instruction::Opcode::ExtractElement:
152+
case Instruction::Opcode::ExtractValue:
153+
case Instruction::Opcode::ShuffleVector:
154+
case Instruction::Opcode::Call:
155+
case Instruction::Opcode::GetElementPtr:
156+
case Instruction::Opcode::Switch:
157+
return ResultReason::Unimplemented;
158+
case Instruction::Opcode::VAArg:
159+
case Instruction::Opcode::Freeze:
160+
case Instruction::Opcode::Fence:
161+
case Instruction::Opcode::Invoke:
162+
case Instruction::Opcode::CallBr:
163+
case Instruction::Opcode::LandingPad:
164+
case Instruction::Opcode::CatchPad:
165+
case Instruction::Opcode::CleanupPad:
166+
case Instruction::Opcode::CatchRet:
167+
case Instruction::Opcode::CleanupRet:
168+
case Instruction::Opcode::Resume:
169+
case Instruction::Opcode::CatchSwitch:
170+
case Instruction::Opcode::AtomicRMW:
171+
case Instruction::Opcode::AtomicCmpXchg:
172+
case Instruction::Opcode::Alloca:
173+
case Instruction::Opcode::Unreachable:
174+
return ResultReason::Infeasible;
175+
}
74176

75177
return std::nullopt;
76178
}

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "llvm/ADT/SmallVector.h"
1212
#include "llvm/SandboxIR/Function.h"
1313
#include "llvm/SandboxIR/Instruction.h"
14+
#include "llvm/SandboxIR/Module.h"
1415
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
1516

1617
namespace llvm::sandboxir {
@@ -40,7 +41,7 @@ static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
4041
}
4142

4243
void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
43-
const auto &LegalityRes = Legality.canVectorize(Bndl);
44+
const auto &LegalityRes = Legality->canVectorize(Bndl);
4445
switch (LegalityRes.getSubclassID()) {
4546
case LegalityResultID::Widen: {
4647
auto *I = cast<Instruction>(Bndl[0]);
@@ -60,6 +61,8 @@ void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
6061
/// Forwards \p Bndl unchanged to vectorizeRec().
void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
  vectorizeRec(Bndl);
}
6162

6263
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
64+
Legality = std::make_unique<LegalityAnalysis>(A.getScalarEvolution(),
65+
F.getParent()->getDataLayout());
6366
Change = false;
6467
// TODO: Start from innermost BBs first
6568
for (auto &BB : F) {

0 commit comments

Comments
 (0)