Skip to content

Commit 2fd4e63

Browse files
committed
[SandboxVec] Notify scheduler about new instructions
This patch registers the "createInstr" callback that notifies the scheduler about newly created instructions. This guarantees that all newly created instructions have a corresponding DAG node associated with them. Without this the pass crashes when the scheduler encounters the newly created vector instructions. This patch also changes the lifetime of the sandboxir Ctx variable in the SandboxVectorizer pass. It needs to be destroyed after the passes get destroyed. Without this change when components like the Scheduler get destroyed Ctx will have already been freed, which is not legal.
1 parent d047488 commit 2fd4e63

File tree

9 files changed

+72
-17
lines changed

9 files changed

+72
-17
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,11 @@ class DependencyGraph {
345345
Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
346346
/// \Returns the range of instructions included in the DAG.
347347
Interval<Instruction> getInterval() const { return DAGInterval; }
348+
/// Called by the scheduler when a new instruction \p I has been created.
349+
void notifyCreateInstr(Instruction *I) {
350+
getOrCreateNode(I);
351+
// TODO: Update the dependencies for the new node.
352+
}
348353
#ifndef NDEBUG
349354
void print(raw_ostream &OS) const;
350355
LLVM_DUMP_METHOD void dump() const;

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,9 @@ class LegalityAnalysis {
162162
const DataLayout &DL;
163163

164164
public:
165-
LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL)
166-
: Sched(AA), SE(SE), DL(DL) {}
165+
LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL,
166+
Context &Ctx)
167+
: Sched(AA, Ctx), SE(SE), DL(DL) {}
167168
/// A LegalityResult factory.
168169
template <typename ResultT, typename... ArgsT>
169170
ResultT &createLegalityResult(ArgsT... Args) {

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "llvm/Analysis/AliasAnalysis.h"
1414
#include "llvm/Analysis/ScalarEvolution.h"
1515
#include "llvm/IR/PassManager.h"
16+
#include "llvm/SandboxIR/Context.h"
1617
#include "llvm/SandboxIR/PassManager.h"
1718

1819
namespace llvm {
@@ -24,6 +25,8 @@ class SandboxVectorizerPass : public PassInfoMixin<SandboxVectorizerPass> {
2425
AAResults *AA = nullptr;
2526
ScalarEvolution *SE = nullptr;
2627

28+
std::unique_ptr<sandboxir::Context> Ctx;
29+
2730
// A pipeline of SandboxIR function passes run by the vectorizer.
2831
sandboxir::FunctionPassManager FPM;
2932

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ class Scheduler {
9595
DependencyGraph DAG;
9696
std::optional<BasicBlock::iterator> ScheduleTopItOpt;
9797
SmallVector<std::unique_ptr<SchedBundle>> Bndls;
98+
Context &Ctx;
99+
Context::CallbackID CreateInstrCB;
98100

99101
/// \Returns a scheduling bundle containing \p Instrs.
100102
SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
@@ -110,8 +112,11 @@ class Scheduler {
110112
Scheduler &operator=(const Scheduler &) = delete;
111113

112114
public:
113-
Scheduler(AAResults &AA) : DAG(AA) {}
114-
~Scheduler() {}
115+
Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) {
116+
CreateInstrCB = Ctx.registerCreateInstrCallback(
117+
[this](Instruction *I) { DAG.notifyCreateInstr(I); });
118+
}
119+
~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); }
115120

116121
bool trySchedule(ArrayRef<Instruction *> Instrs);
117122

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,6 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
181181
}
182182
NewVec = createVectorInstr(Bndl, VecOperands);
183183

184-
// TODO: Notify DAG/Scheduler about new instruction
185-
186184
// TODO: Collect potentially dead instructions.
187185
break;
188186
}
@@ -198,7 +196,8 @@ void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { vectorizeRec(Bndl); }
198196

199197
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
200198
Legality = std::make_unique<LegalityAnalysis>(
201-
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout());
199+
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
200+
F.getContext());
202201
Change = false;
203202
// TODO: Start from innermost BBs first
204203
for (auto &BB : F) {

llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ PreservedAnalyses SandboxVectorizerPass::run(Function &F,
6464
}
6565

6666
bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
67+
if (Ctx == nullptr)
68+
Ctx = std::make_unique<sandboxir::Context>(LLVMF.getContext());
69+
6770
if (PrintPassPipeline) {
6871
FPM.printPipeline(outs());
6972
return false;
@@ -82,8 +85,7 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
8285
}
8386

8487
// Create SandboxIR for LLVMF and run BottomUpVec on it.
85-
sandboxir::Context Ctx(LLVMF.getContext());
86-
sandboxir::Function &F = *Ctx.createFunction(&LLVMF);
88+
sandboxir::Function &F = *Ctx->createFunction(&LLVMF);
8789
sandboxir::Analyses A(*AA, *SE);
8890
return FPM.runOnFunction(F, A);
8991
}

llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,46 @@ define void @store_fpext_load(ptr %ptr) {
5555
ret void
5656
}
5757

58-
; TODO: Test store_zext_fcmp_load once we implement scheduler callbacks and legality diamond check
58+
define void @store_fcmp_zext_load(ptr %ptr) {
59+
; CHECK-LABEL: define void @store_fcmp_zext_load(
60+
; CHECK-SAME: ptr [[PTR:%.*]]) {
61+
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
62+
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
63+
; CHECK-NEXT: [[PTRB0:%.*]] = getelementptr i32, ptr [[PTR]], i32 0
64+
; CHECK-NEXT: [[PTRB1:%.*]] = getelementptr i32, ptr [[PTR]], i32 1
65+
; CHECK-NEXT: [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4
66+
; CHECK-NEXT: [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4
67+
; CHECK-NEXT: [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
68+
; CHECK-NEXT: [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4
69+
; CHECK-NEXT: [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4
70+
; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
71+
; CHECK-NEXT: [[FCMP0:%.*]] = fcmp ogt float [[LDA0]], [[LDB0]]
72+
; CHECK-NEXT: [[FCMP1:%.*]] = fcmp ogt float [[LDA1]], [[LDB1]]
73+
; CHECK-NEXT: [[VCMP:%.*]] = fcmp ogt <2 x float> [[VECL]], [[VECL1]]
74+
; CHECK-NEXT: [[ZEXT0:%.*]] = zext i1 [[FCMP0]] to i32
75+
; CHECK-NEXT: [[ZEXT1:%.*]] = zext i1 [[FCMP1]] to i32
76+
; CHECK-NEXT: [[VCAST:%.*]] = zext <2 x i1> [[VCMP]] to <2 x i32>
77+
; CHECK-NEXT: store i32 [[ZEXT0]], ptr [[PTRB0]], align 4
78+
; CHECK-NEXT: store i32 [[ZEXT1]], ptr [[PTRB1]], align 4
79+
; CHECK-NEXT: store <2 x i32> [[VCAST]], ptr [[PTRB0]], align 4
80+
; CHECK-NEXT: ret void
81+
;
82+
%ptr0 = getelementptr float, ptr %ptr, i32 0
83+
%ptr1 = getelementptr float, ptr %ptr, i32 1
84+
%ptrb0 = getelementptr i32, ptr %ptr, i32 0
85+
%ptrb1 = getelementptr i32, ptr %ptr, i32 1
86+
%ldB0 = load float, ptr %ptr0
87+
%ldB1 = load float, ptr %ptr1
88+
%ldA0 = load float, ptr %ptr0
89+
%ldA1 = load float, ptr %ptr1
90+
%fcmp0 = fcmp ogt float %ldA0, %ldB0
91+
%fcmp1 = fcmp ogt float %ldA1, %ldB1
92+
%zext0 = zext i1 %fcmp0 to i32
93+
%zext1 = zext i1 %fcmp1 to i32
94+
store i32 %zext0, ptr %ptrb0
95+
store i32 %zext1, ptr %ptrb1
96+
ret void
97+
}
5998

6099
; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check
61100

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
110110
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
111111
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
112112

113-
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
113+
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
114114
const auto &Result =
115115
Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
116116
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
@@ -228,7 +228,7 @@ define void @foo(ptr %ptr) {
228228
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
229229
auto *St1 = cast<sandboxir::StoreInst>(&*It++);
230230

231-
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
231+
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
232232
{
233233
// Can vectorize St0,St1.
234234
const auto &Result = Legality.canVectorize({St0, St1});
@@ -262,7 +262,8 @@ define void @foo() {
262262
return Buff == ExpectedStr;
263263
};
264264

265-
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
265+
sandboxir::Context Ctx(C);
266+
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
266267
EXPECT_TRUE(
267268
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
268269
EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,20 +156,20 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
156156

157157
{
158158
// Schedule all instructions in sequence.
159-
sandboxir::Scheduler Sched(getAA(*LLVMF));
159+
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
160160
EXPECT_TRUE(Sched.trySchedule({Ret}));
161161
EXPECT_TRUE(Sched.trySchedule({S1}));
162162
EXPECT_TRUE(Sched.trySchedule({S0}));
163163
}
164164
{
165165
// Skip instructions.
166-
sandboxir::Scheduler Sched(getAA(*LLVMF));
166+
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
167167
EXPECT_TRUE(Sched.trySchedule({Ret}));
168168
EXPECT_TRUE(Sched.trySchedule({S0}));
169169
}
170170
{
171171
// Try invalid scheduling
172-
sandboxir::Scheduler Sched(getAA(*LLVMF));
172+
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
173173
EXPECT_TRUE(Sched.trySchedule({Ret}));
174174
EXPECT_TRUE(Sched.trySchedule({S0}));
175175
EXPECT_FALSE(Sched.trySchedule({S1}));
@@ -197,7 +197,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
197197
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
198198
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
199199

200-
sandboxir::Scheduler Sched(getAA(*LLVMF));
200+
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
201201
EXPECT_TRUE(Sched.trySchedule({Ret}));
202202
EXPECT_TRUE(Sched.trySchedule({S0, S1}));
203203
EXPECT_TRUE(Sched.trySchedule({L0, L1}));

0 commit comments

Comments
 (0)