Skip to content

[SandboxVec] Notify scheduler about new instructions #115102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,11 @@ class DependencyGraph {
Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
/// \Returns the range of instructions included in the DAG.
Interval<Instruction> getInterval() const { return DAGInterval; }
/// Called by the scheduler when a new instruction \p I has been created.
/// Calls getOrCreateNode(I), presumably so that the DAG has a node for \p I
/// (confirm against getOrCreateNode()). The dependencies of that node are
/// not updated here yet; see the TODO below.
void notifyCreateInstr(Instruction *I) {
getOrCreateNode(I);
// TODO: Update the dependencies for the new node.
}
#ifndef NDEBUG
void print(raw_ostream &OS) const;
LLVM_DUMP_METHOD void dump() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,9 @@ class LegalityAnalysis {
const DataLayout &DL;

public:
LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL)
: Sched(AA), SE(SE), DL(DL) {}
/// Builds the analysis from alias analysis \p AA, scalar evolution \p SE and
/// data layout \p DL. \p AA and \p Ctx are forwarded to the internal
/// scheduler (Sched); \p SE and \p DL are stored as members.
LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL,
Context &Ctx)
: Sched(AA, Ctx), SE(SE), DL(DL) {}
/// A LegalityResult factory.
template <typename ResultT, typename... ArgsT>
ResultT &createLegalityResult(ArgsT... Args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/PassManager.h"
#include "llvm/SandboxIR/Context.h"
#include "llvm/SandboxIR/PassManager.h"

namespace llvm {
Expand All @@ -24,6 +25,8 @@ class SandboxVectorizerPass : public PassInfoMixin<SandboxVectorizerPass> {
AAResults *AA = nullptr;
ScalarEvolution *SE = nullptr;

std::unique_ptr<sandboxir::Context> Ctx;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be a member?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it does need to be a member, here is why:
If we create the Context within the scope of runImpl(), like we used to, then Ctx is destroyed as soon as we return from runImpl(). The problem is that the nested members of the SandboxVectorizerPass, like the Scheduler, get destroyed later, after that Context is already gone. This means that if we try to unregister the scheduler's callbacks within the Scheduler destructor, we access freed memory.

Not sure how else we could fix this in a clean way. Having a member context seems to be the easiest fix.


// A pipeline of SandboxIR function passes run by the vectorizer.
sandboxir::FunctionPassManager FPM;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class Scheduler {
DependencyGraph DAG;
std::optional<BasicBlock::iterator> ScheduleTopItOpt;
SmallVector<std::unique_ptr<SchedBundle>> Bndls;
Context &Ctx;
Context::CallbackID CreateInstrCB;

/// \Returns a scheduling bundle containing \p Instrs.
SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
Expand All @@ -110,8 +112,11 @@ class Scheduler {
Scheduler &operator=(const Scheduler &) = delete;

public:
Scheduler(AAResults &AA) : DAG(AA) {}
~Scheduler() {}
/// Builds the DAG from \p AA and registers a create-instruction callback
/// with \p Ctx that forwards each reported instruction to
/// DAG.notifyCreateInstr(). The callback captures `this` and \p Ctx is kept
/// by reference, so \p Ctx must outlive this Scheduler.
Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) {
CreateInstrCB = Ctx.registerCreateInstrCallback(
[this](Instruction *I) { DAG.notifyCreateInstr(I); });
}
/// Unregisters the callback that the constructor installed on Ctx.
~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); }

bool trySchedule(ArrayRef<Instruction *> Instrs);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,6 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
}
NewVec = createVectorInstr(Bndl, VecOperands);

// TODO: Notify DAG/Scheduler about new instruction

// TODO: Collect potentially dead instructions.
break;
}
Expand All @@ -202,7 +200,8 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {

bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
Legality = std::make_unique<LegalityAnalysis>(
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout());
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
F.getContext());
Change = false;
// TODO: Start from innermost BBs first
for (auto &BB : F) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ PreservedAnalyses SandboxVectorizerPass::run(Function &F,
}

bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
if (Ctx == nullptr)
Ctx = std::make_unique<sandboxir::Context>(LLVMF.getContext());

if (PrintPassPipeline) {
FPM.printPipeline(outs());
return false;
Expand All @@ -82,8 +85,7 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
}

// Create SandboxIR for LLVMF and run BottomUpVec on it.
sandboxir::Context Ctx(LLVMF.getContext());
sandboxir::Function &F = *Ctx.createFunction(&LLVMF);
sandboxir::Function &F = *Ctx->createFunction(&LLVMF);
sandboxir::Analyses A(*AA, *SE);
return FPM.runOnFunction(F, A);
}
41 changes: 40 additions & 1 deletion llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,46 @@ define void @store_fpext_load(ptr %ptr) {
ret void
}

; TODO: Test store_zext_fcmp_load once we implement scheduler callbacks and legality diamond check
; Two adjacent i32 stores fed by zext <- fcmp <- float loads. The CHECK lines
; expect a <2 x ...> load/fcmp/zext/store chain to appear next to the original
; scalar instructions, which are still present in the expected output
; (presumably because dead-instruction cleanup is not implemented yet).
define void @store_fcmp_zext_load(ptr %ptr) {
; CHECK-LABEL: define void @store_fcmp_zext_load(
; CHECK-SAME: ptr [[PTR:%.*]]) {
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
; CHECK-NEXT: [[PTRB0:%.*]] = getelementptr i32, ptr [[PTR]], i32 0
; CHECK-NEXT: [[PTRB1:%.*]] = getelementptr i32, ptr [[PTR]], i32 1
; CHECK-NEXT: [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4
; CHECK-NEXT: [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4
; CHECK-NEXT: [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
; CHECK-NEXT: [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4
; CHECK-NEXT: [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4
; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
; CHECK-NEXT: [[FCMP0:%.*]] = fcmp ogt float [[LDA0]], [[LDB0]]
; CHECK-NEXT: [[FCMP1:%.*]] = fcmp ogt float [[LDA1]], [[LDB1]]
; CHECK-NEXT: [[VCMP:%.*]] = fcmp ogt <2 x float> [[VECL]], [[VECL1]]
; CHECK-NEXT: [[ZEXT0:%.*]] = zext i1 [[FCMP0]] to i32
; CHECK-NEXT: [[ZEXT1:%.*]] = zext i1 [[FCMP1]] to i32
; CHECK-NEXT: [[VCAST:%.*]] = zext <2 x i1> [[VCMP]] to <2 x i32>
; CHECK-NEXT: store i32 [[ZEXT0]], ptr [[PTRB0]], align 4
; CHECK-NEXT: store i32 [[ZEXT1]], ptr [[PTRB1]], align 4
; CHECK-NEXT: store <2 x i32> [[VCAST]], ptr [[PTRB0]], align 4
; CHECK-NEXT: ret void
;
%ptr0 = getelementptr float, ptr %ptr, i32 0
%ptr1 = getelementptr float, ptr %ptr, i32 1
%ptrb0 = getelementptr i32, ptr %ptr, i32 0
%ptrb1 = getelementptr i32, ptr %ptr, i32 1
%ldB0 = load float, ptr %ptr0
%ldB1 = load float, ptr %ptr1
%ldA0 = load float, ptr %ptr0
%ldA1 = load float, ptr %ptr1
%fcmp0 = fcmp ogt float %ldA0, %ldB0
%fcmp1 = fcmp ogt float %ldA1, %ldB1
%zext0 = zext i1 %fcmp0 to i32
%zext1 = zext i1 %fcmp1 to i32
store i32 %zext0, ptr %ptrb0
store i32 %zext1, ptr %ptrb1
ret void
}

; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);

sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
const auto &Result =
Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
Expand Down Expand Up @@ -228,7 +228,7 @@ define void @foo(ptr %ptr) {
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
auto *St1 = cast<sandboxir::StoreInst>(&*It++);

sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
{
// Can vectorize St0,St1.
const auto &Result = Legality.canVectorize({St0, St1});
Expand Down Expand Up @@ -262,7 +262,8 @@ define void @foo() {
return Buff == ExpectedStr;
};

sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
sandboxir::Context Ctx(C);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
EXPECT_TRUE(
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,20 +156,20 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {

{
// Schedule all instructions in sequence.
sandboxir::Scheduler Sched(getAA(*LLVMF));
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S1}));
EXPECT_TRUE(Sched.trySchedule({S0}));
}
{
// Skip instructions.
sandboxir::Scheduler Sched(getAA(*LLVMF));
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0}));
}
{
// Try invalid scheduling
sandboxir::Scheduler Sched(getAA(*LLVMF));
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0}));
EXPECT_FALSE(Sched.trySchedule({S1}));
Expand Down Expand Up @@ -197,7 +197,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

sandboxir::Scheduler Sched(getAA(*LLVMF));
sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0, S1}));
EXPECT_TRUE(Sched.trySchedule({L0, L1}));
Expand Down
Loading