Skip to content

Commit 31a4d2c

Browse files
authored
[SandboxVec][DAG] Cleanup: Move callback registration from Scheduler to DAG (#116455)
This is a refactoring patch that moves the callback registration for getting notified about new instructions from the scheduler to the DAG. This makes sense from a design and testing point of view: - the DAG should not rely on the scheduler for getting notified - the notifiers don't need to be public - it's easier to test the notifiers directly from within the DAG unit tests
1 parent c526eb8 commit 31a4d2c

File tree

4 files changed

+72
-32
lines changed

4 files changed

+72
-32
lines changed

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,9 @@ class DependencyGraph {
290290
/// The DAG spans across all instructions in this interval.
291291
Interval<Instruction> DAGInterval;
292292

293+
Context *Ctx = nullptr;
294+
std::optional<Context::CallbackID> CreateInstrCB;
295+
293296
std::unique_ptr<BatchAAResults> BatchAA;
294297

295298
enum class DependencyType {
@@ -325,9 +328,24 @@ class DependencyGraph {
325328
/// chain.
326329
void createNewNodes(const Interval<Instruction> &NewInterval);
327330

331+
/// Called by the callbacks when a new instruction \p I has been created.
332+
void notifyCreateInstr(Instruction *I) {
333+
getOrCreateNode(I);
334+
// TODO: Update the dependencies for the new node.
335+
// TODO: Update the MemDGNode chain to include the new node if needed.
336+
}
337+
328338
public:
329-
DependencyGraph(AAResults &AA)
330-
: BatchAA(std::make_unique<BatchAAResults>(AA)) {}
339+
/// This constructor also registers callbacks.
340+
DependencyGraph(AAResults &AA, Context &Ctx)
341+
: Ctx(&Ctx), BatchAA(std::make_unique<BatchAAResults>(AA)) {
342+
CreateInstrCB = Ctx.registerCreateInstrCallback(
343+
[this](Instruction *I) { notifyCreateInstr(I); });
344+
}
345+
~DependencyGraph() {
346+
if (CreateInstrCB)
347+
Ctx->unregisterCreateInstrCallback(*CreateInstrCB);
348+
}
331349

332350
DGNode *getNode(Instruction *I) const {
333351
auto It = InstrToNodeMap.find(I);
@@ -354,11 +372,6 @@ class DependencyGraph {
354372
Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
355373
/// \Returns the range of instructions included in the DAG.
356374
Interval<Instruction> getInterval() const { return DAGInterval; }
357-
/// Called by the scheduler when a new instruction \p I has been created.
358-
void notifyCreateInstr(Instruction *I) {
359-
getOrCreateNode(I);
360-
// TODO: Update the dependencies for the new node.
361-
}
362375
void clear() {
363376
InstrToNodeMap.clear();
364377
DAGInterval = {};

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,6 @@ class Scheduler {
106106
std::optional<BasicBlock::iterator> ScheduleTopItOpt;
107107
// TODO: This is wasting memory in exchange for fast removal using a raw ptr.
108108
DenseMap<SchedBundle *, std::unique_ptr<SchedBundle>> Bndls;
109-
Context &Ctx;
110-
Context::CallbackID CreateInstrCB;
111109

112110
/// \Returns a scheduling bundle containing \p Instrs.
113111
SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
@@ -137,11 +135,8 @@ class Scheduler {
137135
Scheduler &operator=(const Scheduler &) = delete;
138136

139137
public:
140-
Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) {
141-
CreateInstrCB = Ctx.registerCreateInstrCallback(
142-
[this](Instruction *I) { DAG.notifyCreateInstr(I); });
143-
}
144-
~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); }
138+
Scheduler(AAResults &AA, Context &Ctx) : DAG(AA, Ctx) {}
139+
~Scheduler() {}
145140

146141
bool trySchedule(ArrayRef<Instruction *> Instrs);
147142

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ define void @foo(i8 %v1, ptr %ptr) {
194194
auto *Call = cast<sandboxir::CallInst>(&*It++);
195195
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
196196

197-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
197+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
198198
DAG.extend({&*BB->begin(), BB->getTerminator()});
199199
EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Store)));
200200
EXPECT_TRUE(isa<llvm::sandboxir::MemDGNode>(DAG.getNode(Load)));
@@ -224,7 +224,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
224224
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
225225
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
226226
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
227-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
227+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
228228
auto Span = DAG.extend({&*BB->begin(), BB->getTerminator()});
229229
// Check extend().
230230
EXPECT_EQ(Span.top(), &*BB->begin());
@@ -285,7 +285,7 @@ define i8 @foo(i8 %v0, i8 %v1) {
285285
auto *F = Ctx.createFunction(LLVMF);
286286
auto *BB = &*F->begin();
287287
auto It = BB->begin();
288-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
288+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
289289
DAG.extend({&*BB->begin(), BB->getTerminator()});
290290

291291
auto *AddN0 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*It++));
@@ -332,7 +332,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
332332
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
333333
[[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
334334

335-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
335+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
336336
DAG.extend({&*BB->begin(), BB->getTerminator()});
337337

338338
auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
@@ -366,7 +366,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
366366
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
367367
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
368368

369-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
369+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
370370
DAG.extend({&*BB->begin(), BB->getTerminator()});
371371

372372
auto *S0N = cast<sandboxir::MemDGNode>(DAG.getNode(S0));
@@ -436,7 +436,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
436436
sandboxir::Context Ctx(C);
437437
auto *F = Ctx.createFunction(LLVMF);
438438
auto *BB = &*F->begin();
439-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
439+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
440440
DAG.extend({&*BB->begin(), BB->getTerminator()});
441441
auto It = BB->begin();
442442
auto *Store0N = cast<sandboxir::MemDGNode>(
@@ -461,7 +461,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v0, i8 %v1) {
461461
sandboxir::Context Ctx(C);
462462
auto *F = Ctx.createFunction(LLVMF);
463463
auto *BB = &*F->begin();
464-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
464+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
465465
DAG.extend({&*BB->begin(), BB->getTerminator()});
466466
auto It = BB->begin();
467467
auto *Store0N = cast<sandboxir::MemDGNode>(
@@ -487,7 +487,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
487487
sandboxir::Context Ctx(C);
488488
auto *F = Ctx.createFunction(LLVMF);
489489
auto *BB = &*F->begin();
490-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
490+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
491491
DAG.extend({&*BB->begin(), BB->getTerminator()});
492492
auto It = BB->begin();
493493
auto *Ld0N = cast<sandboxir::MemDGNode>(
@@ -512,7 +512,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1, i8 %v) {
512512
sandboxir::Context Ctx(C);
513513
auto *F = Ctx.createFunction(LLVMF);
514514
auto *BB = &*F->begin();
515-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
515+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
516516
DAG.extend({&*BB->begin(), BB->getTerminator()});
517517
auto It = BB->begin();
518518
auto *Store0N = cast<sandboxir::MemDGNode>(
@@ -542,7 +542,7 @@ define void @foo(float %v1, float %v2) {
542542
auto *F = Ctx.createFunction(LLVMF);
543543
auto *BB = &*F->begin();
544544

545-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
545+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
546546
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
547547

548548
auto It = BB->begin();
@@ -574,7 +574,7 @@ define void @foo() {
574574
auto *F = Ctx.createFunction(LLVMF);
575575
auto *BB = &*F->begin();
576576

577-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
577+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
578578
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
579579

580580
auto It = BB->begin();
@@ -606,7 +606,7 @@ define void @foo(i8 %v0, i8 %v1, ptr %ptr) {
606606
auto *F = Ctx.createFunction(LLVMF);
607607
auto *BB = &*F->begin();
608608

609-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
609+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
610610
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
611611

612612
auto It = BB->begin();
@@ -637,7 +637,7 @@ define void @foo(ptr %ptr) {
637637
auto *F = Ctx.createFunction(LLVMF);
638638
auto *BB = &*F->begin();
639639

640-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
640+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
641641
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
642642

643643
auto It = BB->begin();
@@ -664,7 +664,7 @@ define void @foo(ptr %ptr) {
664664
auto *F = Ctx.createFunction(LLVMF);
665665
auto *BB = &*F->begin();
666666

667-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
667+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
668668
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
669669

670670
auto It = BB->begin();
@@ -695,7 +695,7 @@ define void @foo() {
695695
auto *F = Ctx.createFunction(LLVMF);
696696
auto *BB = &*F->begin();
697697

698-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
698+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
699699
DAG.extend({&*BB->begin(), BB->getTerminator()->getPrevNode()});
700700

701701
auto It = BB->begin();
@@ -728,7 +728,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
728728
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
729729
auto *S4 = cast<sandboxir::StoreInst>(&*It++);
730730
auto *S5 = cast<sandboxir::StoreInst>(&*It++);
731-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
731+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
732732
{
733733
// Scenario 1: Build new DAG
734734
auto NewIntvl = DAG.extend({S3, S3});
@@ -788,7 +788,7 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
788788

789789
{
790790
// Check UnscheduledSuccs when a node is scheduled
791-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
791+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
792792
DAG.extend({S2, S2});
793793
auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
794794
S2N->setScheduled(true);
@@ -798,3 +798,35 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
798798
EXPECT_EQ(S1N->getNumUnscheduledSuccs(), 0u); // S1 is scheduled
799799
}
800800
}
801+
802+
TEST_F(DependencyGraphTest, CreateInstrCallback) {
803+
parseIR(C, R"IR(
804+
define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %arg) {
805+
store i8 %v1, ptr %ptr
806+
store i8 %v2, ptr %ptr
807+
store i8 %v3, ptr %ptr
808+
ret void
809+
}
810+
)IR");
811+
llvm::Function *LLVMF = &*M->getFunction("foo");
812+
sandboxir::Context Ctx(C);
813+
auto *F = Ctx.createFunction(LLVMF);
814+
auto *BB = &*F->begin();
815+
auto It = BB->begin();
816+
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
817+
[[maybe_unused]] auto *S2 = cast<sandboxir::StoreInst>(&*It++);
818+
auto *S3 = cast<sandboxir::StoreInst>(&*It++);
819+
820+
// Check new instruction callback.
821+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
822+
DAG.extend({S1, S3});
823+
auto *Arg = F->getArg(3);
824+
auto *Ptr = S1->getPointerOperand();
825+
sandboxir::StoreInst *NewS =
826+
sandboxir::StoreInst::create(Arg, Ptr, Align(8), S3->getIterator(),
827+
/*IsVolatile=*/true, Ctx);
828+
auto *NewSN = DAG.getNode(NewS);
829+
EXPECT_TRUE(NewSN != nullptr);
830+
// TODO: Check the dependencies to/from NewSN after they land.
831+
// TODO: Check the MemDGNode chain.
832+
}

llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
7070
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
7171
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
7272

73-
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
73+
sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
7474
DAG.extend({&*BB->begin(), BB->getTerminator()});
7575
auto *SN0 = DAG.getNode(S0);
7676
auto *SN1 = DAG.getNode(S1);

0 commit comments

Comments
 (0)