Skip to content

Commit f7ef7b2

Browse files
authored
[SandboxVec][Scheduler] Implement rescheduling (#115220)
This patch adds support for re-scheduling already scheduled instructions. For now this will clear and rebuild the DAG, and will reschedule the code using the new DAG.
1 parent ae6dbed commit f7ef7b2

File tree

7 files changed

+226
-24
lines changed

7 files changed

+226
-24
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace llvm::sandboxir {
3333

3434
class DependencyGraph;
3535
class MemDGNode;
36+
class SchedBundle;
3637

3738
/// SubclassIDs for isa/dyn_cast etc.
3839
enum class DGNodeID {
@@ -100,6 +101,12 @@ class DGNode {
100101
unsigned UnscheduledSuccs = 0;
101102
/// This is true if this node has been scheduled.
102103
bool Scheduled = false;
104+
/// The scheduler bundle that this node belongs to.
105+
SchedBundle *SB = nullptr;
106+
107+
/// Records \p SB as the bundle that this node belongs to. The parameter
/// shadows the `SB` member, hence the explicit `this->`.
void setSchedBundle(SchedBundle &SB) { this->SB = &SB; }
108+
/// Resets the bundle back-pointer to null, i.e. the node no longer reports
/// membership in any bundle.
void clearSchedBundle() { this->SB = nullptr; }
109+
friend class SchedBundle; // For setSchedBundle(), clearSchedBundle().
103110

104111
DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {}
105112
friend class MemDGNode; // For constructor.
@@ -122,6 +129,8 @@ class DGNode {
122129
/// \Returns true if this node has been scheduled.
123130
bool scheduled() const { return Scheduled; }
124131
void setScheduled(bool NewVal) { Scheduled = NewVal; }
132+
/// \Returns the scheduling bundle that this node belongs to, or nullptr if no
/// bundle has been set (or it has been cleared again).
133+
SchedBundle *getSchedBundle() const { return SB; }
125134
/// \Returns true if this is before \p Other in program order.
126135
bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
127136
using iterator = PredIterator;
@@ -350,6 +359,10 @@ class DependencyGraph {
350359
getOrCreateNode(I);
351360
// TODO: Update the dependencies for the new node.
352361
}
362+
/// Resets the graph: removes every entry from InstrToNodeMap and assigns a
/// value-initialized interval to DAGInterval.
/// NOTE(review): presumably InstrToNodeMap owns the DGNodes, in which case any
/// DGNode pointer obtained before this call dangles afterwards -- confirm.
void clear() {
363+
InstrToNodeMap.clear();
364+
DAGInterval = {};
365+
}
353366
#ifndef NDEBUG
354367
void print(raw_ostream &OS) const;
355368
LLVM_DUMP_METHOD void dump() const;

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class ReadyListContainer {
5353
return Back;
5454
}
5555
bool empty() const { return List.empty(); }
56+
/// Empties the ready list by assigning a value-initialized container to List.
/// NOTE(review): presumably List's type has no clear() of its own (e.g. a
/// std::priority_queue), hence the reassignment -- confirm.
void clear() { List = {}; }
5657
#ifndef NDEBUG
5758
void dump(raw_ostream &OS) const;
5859
LLVM_DUMP_METHOD void dump() const;
@@ -70,7 +71,16 @@ class SchedBundle {
7071

7172
public:
7273
SchedBundle() = default;
73-
SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {}
74+
SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {
75+
for (auto *N : this->Nodes)
76+
N->setSchedBundle(*this);
77+
}
78+
~SchedBundle() {
79+
for (auto *N : this->Nodes)
80+
N->clearSchedBundle();
81+
}
82+
/// \Returns true if this bundle contains no nodes.
bool empty() const { return Nodes.empty(); }
83+
/// \Returns the last node stored in the bundle.
/// NOTE(review): Nodes.back() is called unconditionally, so this presumably
/// requires a non-empty bundle -- confirm.
DGNode *back() const { return Nodes.back(); }
7484
using iterator = ContainerTy::iterator;
7585
using const_iterator = ContainerTy::const_iterator;
7686
iterator begin() { return Nodes.begin(); }
@@ -94,19 +104,34 @@ class Scheduler {
94104
ReadyListContainer ReadyList;
95105
DependencyGraph DAG;
96106
std::optional<BasicBlock::iterator> ScheduleTopItOpt;
97-
SmallVector<std::unique_ptr<SchedBundle>> Bndls;
107+
// TODO: This is wasting memory in exchange for fast removal using a raw ptr.
108+
DenseMap<SchedBundle *, std::unique_ptr<SchedBundle>> Bndls;
98109
Context &Ctx;
99110
Context::CallbackID CreateInstrCB;
100111

101112
/// \Returns a scheduling bundle containing \p Instrs.
102113
SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
114+
void eraseBundle(SchedBundle *SB);
103115
/// Schedule nodes until we can schedule \p Instrs back-to-back.
104116
bool tryScheduleUntil(ArrayRef<Instruction *> Instrs);
105117
/// Schedules all nodes in \p Bndl, marks them as scheduled, updates the
106118
/// UnscheduledSuccs counter of all dependency predecessors, and adds any of
107119
/// them that become ready to the ready list.
108120
void scheduleAndUpdateReadyList(SchedBundle &Bndl);
109-
121+
/// The scheduling state of the instructions in the bundle.
122+
enum class BndlSchedState {
123+
NoneScheduled, ///< No instruction in the bundle was previously scheduled.
124+
PartiallyOrDifferentlyScheduled, ///< Only some of the instrs in the bundle
125+
///< were previously scheduled, or all of
126+
///< them were but not in the same
127+
///< SchedBundle.
128+
FullyScheduled, ///< All instrs in the bundle were previously scheduled and
129+
///< were in the same SchedBundle.
130+
};
131+
/// \Returns whether none/some/all of \p Instrs have been scheduled.
132+
BndlSchedState getBndlSchedState(ArrayRef<Instruction *> Instrs) const;
133+
/// Destroy the top-most part of the schedule that includes \p Instrs.
134+
void trimSchedule(ArrayRef<Instruction *> Instrs);
110135
/// Disable copies.
111136
Scheduler(const Scheduler &) = delete;
112137
Scheduler &operator=(const Scheduler &) = delete;

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,14 @@ class VecUtils {
100100
}
101101
return FixedVectorType::get(ElemTy, NumElts);
102102
}
103+
/// \Returns the instruction of \p Instrs that no other member comes after in
/// program order, i.e. the bottom-most one. \p Instrs must not be empty,
/// because its first element seeds the search.
static Instruction *getLowest(ArrayRef<Instruction *> Instrs) {
  Instruction *Bottom = Instrs.front();
  for (Instruction *Candidate : drop_begin(Instrs))
    Bottom = Bottom->comesBefore(Candidate) ? Candidate : Bottom;
  return Bottom;
}
103111
};
104112

105113
} // namespace llvm::sandboxir

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp

Lines changed: 79 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
10+
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
1011

1112
namespace llvm::sandboxir {
1213

@@ -95,10 +96,12 @@ SchedBundle *Scheduler::createBundle(ArrayRef<Instruction *> Instrs) {
9596
Nodes.push_back(DAG.getNode(I));
9697
auto BndlPtr = std::make_unique<SchedBundle>(std::move(Nodes));
9798
auto *Bndl = BndlPtr.get();
98-
Bndls.push_back(std::move(BndlPtr));
99+
Bndls[Bndl] = std::move(BndlPtr);
99100
return Bndl;
100101
}
101102

103+
// Removes \p SB from Bndls. The map entry holds the owning unique_ptr, so
// erasing it destroys the bundle, and ~SchedBundle() in turn clears the bundle
// back-pointer of each of its nodes. \p SB must not be used after this call.
void Scheduler::eraseBundle(SchedBundle *SB) { Bndls.erase(SB); }
104+
102105
bool Scheduler::tryScheduleUntil(ArrayRef<Instruction *> Instrs) {
103106
// Use a set of instructions, instead of `Instrs` for fast lookups.
104107
DenseSet<Instruction *> InstrsToDefer(Instrs.begin(), Instrs.end());
@@ -133,29 +136,88 @@ bool Scheduler::tryScheduleUntil(ArrayRef<Instruction *> Instrs) {
133136
return false;
134137
}
135138

139+
// Classifies \p Instrs by how they relate to the current schedule:
//  - NoneScheduled: no instruction has a DAG node that is marked scheduled.
//  - FullyScheduled: every instruction is scheduled and they all share one
//    SchedBundle.
//  - PartiallyOrDifferentlyScheduled: everything in between, i.e. only some
//    are scheduled, or all are but they live in different bundles.
// An instruction without a DAG node counts as not scheduled.
Scheduler::BndlSchedState
Scheduler::getBndlSchedState(ArrayRef<Instruction *> Instrs) const {
  assert(!Instrs.empty() && "Expected non-empty bundle");
  bool AnyScheduled = false;
  bool AllScheduled = true;
  for (Instruction *I : Instrs) {
    auto *N = DAG.getNode(I);
    const bool IsScheduled = N != nullptr && N->scheduled();
    AnyScheduled = AnyScheduled || IsScheduled;
    AllScheduled = AllScheduled && IsScheduled;
  }
  if (!AllScheduled)
    return AnyScheduled ? BndlSchedState::PartiallyOrDifferentlyScheduled
                        : BndlSchedState::NoneScheduled;

  // Every instruction is scheduled, so each one has a node. If they are not
  // all members of the same SchedBundle we still have to re-schedule, which is
  // why that case is reported as partially scheduled.
  SchedBundle *FirstSB = DAG.getNode(Instrs.front())->getSchedBundle();
  assert(FirstSB != nullptr && "FullyScheduled assumes that there is an SB!");
  const bool SameBundle =
      all_of(drop_begin(Instrs), [this, FirstSB](Instruction *I) {
        return DAG.getNode(I)->getSchedBundle() == FirstSB;
      });
  return SameBundle ? BndlSchedState::FullyScheduled
                    : BndlSchedState::PartiallyOrDifferentlyScheduled;
}
167+
168+
// Destroys the top-most part of the schedule that includes \p Instrs so that
// it can be scheduled again from scratch. Expects that scheduling has already
// started, i.e. that ScheduleTopItOpt holds a value.
void Scheduler::trimSchedule(ArrayRef<Instruction *> Instrs) {
  Instruction *TopI = &*ScheduleTopItOpt.value();
  Instruction *LowestI = VecUtils::getLowest(Instrs);
  // Destroy the schedule bundles from LowestI all the way to the top.
  for (auto *I = LowestI, *E = TopI->getPrevNode(); I != E;
       I = I->getPrevNode()) {
    auto *N = DAG.getNode(I);
    // Not every instruction in this range has a node: a member of Instrs that
    // was never added to the DAG is reported by getBndlSchedState() as "not
    // scheduled" and still ends up here. There is no bundle to destroy then.
    if (N == nullptr)
      continue;
    // Erasing the bundle clears the back-pointer of all its nodes, so the
    // remaining members of the same bundle are skipped by this check.
    if (auto *SB = N->getSchedBundle())
      eraseBundle(SB);
  }
  // TODO: For now we clear the DAG. Trim view once it gets implemented.
  // The bundles go first: ~SchedBundle() writes to the nodes of the bundle,
  // which presumably stop existing once the DAG has been cleared.
  Bndls.clear();
  DAG.clear();

  // The schedule is rebuilt from scratch, so the ready list is cleared too:
  // its entries refer to nodes of the DAG that has just been cleared.
  ReadyList.clear();
}
186+
136187
bool Scheduler::trySchedule(ArrayRef<Instruction *> Instrs) {
137188
assert(all_of(drop_begin(Instrs),
138189
[Instrs](Instruction *I) {
139190
return I->getParent() == (*Instrs.begin())->getParent();
140191
}) &&
141192
"Instrs not in the same BB!");
142-
// Extend the DAG to include Instrs.
143-
Interval<Instruction> Extension = DAG.extend(Instrs);
144-
// TODO: Set the window of the DAG that we are interested in.
145-
// We start scheduling at the bottom instr of Instrs.
146-
auto getBottomI = [](ArrayRef<Instruction *> Instrs) -> Instruction * {
147-
return *min_element(Instrs,
148-
[](auto *I1, auto *I2) { return I1->comesBefore(I2); });
149-
};
150-
ScheduleTopItOpt = std::next(getBottomI(Instrs)->getIterator());
151-
// Add nodes to ready list.
152-
for (auto &I : Extension) {
153-
auto *N = DAG.getNode(&I);
154-
if (N->ready())
155-
ReadyList.insert(N);
193+
auto SchedState = getBndlSchedState(Instrs);
194+
switch (SchedState) {
195+
case BndlSchedState::FullyScheduled:
196+
// Nothing to do.
197+
return true;
198+
case BndlSchedState::PartiallyOrDifferentlyScheduled:
199+
// If one or more instrs are already scheduled we need to destroy the
200+
// top-most part of the schedule that includes the instrs in the bundle and
201+
// re-schedule.
202+
trimSchedule(Instrs);
203+
[[fallthrough]];
204+
case BndlSchedState::NoneScheduled: {
205+
// TODO: Set the window of the DAG that we are interested in.
206+
// We start scheduling at the bottom instr of Instrs.
207+
ScheduleTopItOpt = std::next(VecUtils::getLowest(Instrs)->getIterator());
208+
209+
// Extend the DAG to include Instrs.
210+
Interval<Instruction> Extension = DAG.extend(Instrs);
211+
// Add nodes to ready list.
212+
for (auto &I : Extension) {
213+
auto *N = DAG.getNode(&I);
214+
if (N->ready())
215+
ReadyList.insert(N);
216+
}
217+
// Try schedule all nodes until we can schedule Instrs back-to-back.
218+
return tryScheduleUntil(Instrs);
219+
}
156220
}
157-
// Try schedule all nodes until we can schedule Instrs back-to-back.
158-
return tryScheduleUntil(Instrs);
159221
}
160222

161223
#ifndef NDEBUG

llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,37 @@ define void @store_fcmp_zext_load(ptr %ptr) {
9696
ret void
9797
}
9898

99-
; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check
99+
; Two scalar store <- fadd <- load chains on adjacent float elements, where both
; fadd operands are loads. The expected output contains a <2 x float> load for
; each operand, a <2 x float> fadd and a <2 x float> store, emitted next to the
; original scalar instructions, which are still present.
define void @store_fadd_load(ptr %ptr) {
100+
; CHECK-LABEL: define void @store_fadd_load(
101+
; CHECK-SAME: ptr [[PTR:%.*]]) {
102+
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
103+
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
104+
; CHECK-NEXT: [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4
105+
; CHECK-NEXT: [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4
106+
; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
107+
; CHECK-NEXT: [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4
108+
; CHECK-NEXT: [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4
109+
; CHECK-NEXT: [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
110+
; CHECK-NEXT: [[FADD0:%.*]] = fadd float [[LDA0]], [[LDB0]]
111+
; CHECK-NEXT: [[FADD1:%.*]] = fadd float [[LDA1]], [[LDB1]]
112+
; CHECK-NEXT: [[VEC:%.*]] = fadd <2 x float> [[VECL]], [[VECL1]]
113+
; CHECK-NEXT: store float [[FADD0]], ptr [[PTR0]], align 4
114+
; CHECK-NEXT: store float [[FADD1]], ptr [[PTR1]], align 4
115+
; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4
116+
; CHECK-NEXT: ret void
117+
;
118+
%ptr0 = getelementptr float, ptr %ptr, i32 0
119+
%ptr1 = getelementptr float, ptr %ptr, i32 1
120+
%ldA0 = load float, ptr %ptr0
121+
%ldA1 = load float, ptr %ptr1
122+
%ldB0 = load float, ptr %ptr0
123+
%ldB1 = load float, ptr %ptr1
124+
%fadd0 = fadd float %ldA0, %ldB0
125+
%fadd1 = fadd float %ldA1, %ldB1
126+
store float %fadd0, ptr %ptr0
127+
store float %fadd1, ptr %ptr1
128+
ret void
129+
}
100130

101131
define void @store_fneg_load(ptr %ptr) {
102132
; CHECK-LABEL: define void @store_fneg_load(

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,10 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
168168
EXPECT_TRUE(Sched.trySchedule({S0}));
169169
}
170170
{
171-
// Try invalid scheduling
171+
// Try invalid scheduling. Dependency S0->S1.
172172
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
173173
EXPECT_TRUE(Sched.trySchedule({Ret}));
174-
EXPECT_TRUE(Sched.trySchedule({S0}));
175-
EXPECT_FALSE(Sched.trySchedule({S1}));
174+
EXPECT_FALSE(Sched.trySchedule({S0, S1}));
176175
}
177176
}
178177

@@ -202,3 +201,39 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
202201
EXPECT_TRUE(Sched.trySchedule({S0, S1}));
203202
EXPECT_TRUE(Sched.trySchedule({L0, L1}));
204203
}
204+
205+
// Schedules {Ret}, {S0, S1} and {L0, L1} first, then asks for the bundle
// {Add0, Add1}, whose members are expected to be scheduled already (each on its
// own). Both that request and a repeated {L0, L1} request must succeed.
TEST_F(SchedulerTest, RescheduleAlreadyScheduled) {
206+
parseIR(C, R"IR(
207+
define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
208+
%ld0 = load i8, ptr %ptr0
209+
%ld1 = load i8, ptr %ptr1
210+
%add0 = add i8 %ld0, %ld0
211+
%add1 = add i8 %ld1, %ld1
212+
store i8 %add0, ptr %ptr0
213+
store i8 %add1, ptr %ptr1
214+
ret void
215+
}
216+
)IR");
217+
llvm::Function *LLVMF = &*M->getFunction("foo");
218+
sandboxir::Context Ctx(C);
219+
auto *F = Ctx.createFunction(LLVMF);
220+
auto *BB = &*F->begin();
221+
auto It = BB->begin();
222+
auto *L0 = cast<sandboxir::LoadInst>(&*It++);
223+
auto *L1 = cast<sandboxir::LoadInst>(&*It++);
224+
auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
225+
auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
226+
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
227+
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
228+
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
229+
230+
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
231+
EXPECT_TRUE(Sched.trySchedule({Ret}));
232+
EXPECT_TRUE(Sched.trySchedule({S0, S1}));
233+
EXPECT_TRUE(Sched.trySchedule({L0, L1}));
234+
// At this point Add0 and Add1 should have been individually scheduled
235+
// as single bundles.
236+
// Check if rescheduling works.
237+
EXPECT_TRUE(Sched.trySchedule({Add0, Add1}));
238+
EXPECT_TRUE(Sched.trySchedule({L0, L1}));
239+
}

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,3 +410,32 @@ TEST_F(VecUtilsTest, GetWideType) {
410410
auto *Int32X8Ty = sandboxir::FixedVectorType::get(Int32Ty, 8);
411411
EXPECT_EQ(sandboxir::VecUtils::getWideType(Int32X4Ty, 2), Int32X8Ty);
412412
}
413+
414+
// Checks that VecUtils::getLowest() returns %C, the last of the three adds in
// the block, no matter at which position it appears in the input.
TEST_F(VecUtilsTest, GetLowest) {
415+
parseIR(R"IR(
416+
define void @foo(i8 %v) {
417+
bb0:
418+
%A = add i8 %v, %v
419+
%B = add i8 %v, %v
420+
%C = add i8 %v, %v
421+
ret void
422+
}
423+
)IR");
424+
Function &LLVMF = *M->getFunction("foo");
425+
426+
sandboxir::Context Ctx(C);
427+
auto &F = *Ctx.createFunction(&LLVMF);
428+
auto &BB = *F.begin();
429+
auto It = BB.begin();
430+
auto *IA = &*It++;
431+
auto *IB = &*It++;
432+
auto *IC = &*It++;
433+
SmallVector<sandboxir::Instruction *> ABC({IA, IB, IC});
434+
EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC), IC);
435+
SmallVector<sandboxir::Instruction *> ACB({IA, IC, IB});
436+
EXPECT_EQ(sandboxir::VecUtils::getLowest(ACB), IC);
437+
SmallVector<sandboxir::Instruction *> CAB({IC, IA, IB});
438+
EXPECT_EQ(sandboxir::VecUtils::getLowest(CAB), IC);
439+
SmallVector<sandboxir::Instruction *> CBA({IC, IB, IA});
440+
EXPECT_EQ(sandboxir::VecUtils::getLowest(CBA), IC);
441+
}

0 commit comments

Comments
 (0)