Skip to content

[SandboxVec] Add pass to create Regions from metadata. Generalize SandboxVec pass pipelines. #112288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/include/llvm/SandboxIR/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Pass {
LLVM_DUMP_METHOD virtual void dump() const;
#endif
/// Similar to print() but adds a newline. Used for testing.
void printPipeline(raw_ostream &OS) const { OS << Name << "\n"; }
virtual void printPipeline(raw_ostream &OS) const { OS << Name << "\n"; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be guarde by #ifndef NDEBUG ?

Copy link
Collaborator Author

@slackito slackito Oct 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the code that implements the -sbvec-print-pass-pipeline flag. And it wasn't guarded before. Do we want the flag to only work in debug builds? If so, should we guard the flag as well? (also the tests that rely on printing the pass pipeline will become unsupported in release builds) I think those questions are out of scope for this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, it's because of the tests, makes sense.

};

/// A pass that runs on a sandbox::Function.
Expand Down
133 changes: 116 additions & 17 deletions llvm/include/llvm/SandboxIR/PassManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,20 @@ class Value;
/// Base class.
template <typename ParentPass, typename ContainedPass>
class PassManager : public ParentPass {
public:
// CreatePassFunc(StringRef PassName, StringRef PassArgs).
using CreatePassFunc =
std::function<std::unique_ptr<ContainedPass>(StringRef, StringRef)>;

protected:
/// The list of passes that this pass manager will run.
SmallVector<std::unique_ptr<ContainedPass>> Passes;

PassManager(StringRef Name) : ParentPass(Name) {}
PassManager(StringRef Name, StringRef Pipeline, CreatePassFunc CreatePass)
: ParentPass(Name) {
setPassPipeline(Pipeline, CreatePass);
}
PassManager(const PassManager &) = delete;
PassManager(PassManager &&) = default;
virtual ~PassManager() = default;
Expand All @@ -49,41 +58,125 @@ class PassManager : public ParentPass {
Passes.push_back(std::move(Pass));
}

using CreatePassFunc =
std::function<std::unique_ptr<ContainedPass>(StringRef)>;

/// Parses \p Pipeline as a comma-separated sequence of pass names and sets
/// the pass pipeline, using \p CreatePass to instantiate passes by name.
///
/// After calling this function, the PassManager contains only the specified
/// pipeline, any previously added passes are cleared.
/// Passes can have arguments, for example:
/// "pass1<arg1,arg2>,pass2,pass3<arg3,arg4>"
///
/// The arguments between angle brackets are treated as a mostly opaque string
/// and each pass is responsible for parsing its arguments. The exception to
/// this are nested angle brackets, which must match pair-wise to allow
/// arguments to contain nested pipelines, like:
///
/// "pass1<subpass1,subpass2<arg1,arg2>,subpass3>"
///
/// An empty args string is treated the same as no args, so "pass" and
/// "pass<>" are equivalent.
void setPassPipeline(StringRef Pipeline, CreatePassFunc CreatePass) {
static constexpr const char EndToken = '\0';
static constexpr const char BeginArgsToken = '<';
static constexpr const char EndArgsToken = '>';
static constexpr const char PassDelimToken = ',';

assert(Passes.empty() &&
"setPassPipeline called on a non-empty sandboxir::PassManager");

// Accept an empty pipeline as a special case. This can be useful, for
// example, to test conversion to SandboxIR without running any passes on
// it.
if (Pipeline.empty())
return;

// Add EndToken to the end to ease parsing.
std::string PipelineStr = std::string(Pipeline) + EndToken;
int FlagBeginIdx = 0;

for (auto [Idx, C] : enumerate(PipelineStr)) {
// Keep moving Idx until we find the end of the pass name.
bool FoundDelim = C == EndToken || C == PassDelimToken;
if (!FoundDelim)
continue;
unsigned Sz = Idx - FlagBeginIdx;
std::string PassName(&PipelineStr[FlagBeginIdx], Sz);
FlagBeginIdx = Idx + 1;
Pipeline = StringRef(PipelineStr);

auto AddPass = [this, CreatePass](StringRef PassName, StringRef PassArgs) {
if (PassName.empty()) {
errs() << "Found empty pass name.\n";
exit(1);
}
// Get the pass that corresponds to PassName and add it to the pass
// manager.
auto Pass = CreatePass(PassName);
auto Pass = CreatePass(PassName, PassArgs);
if (Pass == nullptr) {
errs() << "Pass '" << PassName << "' not registered!\n";
exit(1);
}
addPass(std::move(Pass));
};

enum class State {
ScanName, // reading a pass name
ScanArgs, // reading a list of args
ArgsEnded, // read the last '>' in an args list, must read delimiter next
} CurrentState = State::ScanName;
int PassBeginIdx = 0;
int ArgsBeginIdx;
StringRef PassName;
int NestedArgs = 0;
for (auto [Idx, C] : enumerate(Pipeline)) {
switch (CurrentState) {
case State::ScanName:
if (C == BeginArgsToken) {
// Save pass name for later and begin scanning args.
PassName = Pipeline.slice(PassBeginIdx, Idx);
ArgsBeginIdx = Idx + 1;
++NestedArgs;
CurrentState = State::ScanArgs;
break;
}
if (C == EndArgsToken) {
errs() << "Unexpected '>' in pass pipeline.\n";
exit(1);
}
if (C == EndToken || C == PassDelimToken) {
// Delimiter found, add the pass (with empty args), stay in the
// ScanName state.
AddPass(Pipeline.slice(PassBeginIdx, Idx), StringRef());
PassBeginIdx = Idx + 1;
}
break;
case State::ScanArgs:
// While scanning args, we only care about making sure nesting of angle
// brackets is correct.
if (C == BeginArgsToken) {
++NestedArgs;
break;
}
if (C == EndArgsToken) {
--NestedArgs;
if (NestedArgs == 0) {
// Done scanning args.
AddPass(PassName, Pipeline.slice(ArgsBeginIdx, Idx));
CurrentState = State::ArgsEnded;
} else if (NestedArgs < 0) {
errs() << "Unexpected '>' in pass pipeline.\n";
exit(1);
}
break;
}
if (C == EndToken) {
errs() << "Missing '>' in pass pipeline. End-of-string reached while "
"reading arguments for pass '"
<< PassName << "'.\n";
exit(1);
}
break;
case State::ArgsEnded:
// Once we're done scanning args, only a delimiter is valid. This avoids
// accepting strings like "foo<args><more-args>" or "foo<args>bar".
if (C == EndToken || C == PassDelimToken) {
PassBeginIdx = Idx + 1;
CurrentState = State::ScanName;
} else {
errs() << "Expected delimiter or end-of-string after pass "
"arguments.\n";
exit(1);
}
break;
}
}
}

Expand All @@ -101,7 +194,7 @@ class PassManager : public ParentPass {
}
#endif
/// Similar to print() but prints one pass per line. Used for testing.
void printPipeline(raw_ostream &OS) const {
void printPipeline(raw_ostream &OS) const override {
OS << this->getName() << "\n";
for (const auto &PassPtr : Passes)
PassPtr->printPipeline(OS);
Expand All @@ -112,12 +205,18 @@ class FunctionPassManager final
: public PassManager<FunctionPass, FunctionPass> {
public:
FunctionPassManager(StringRef Name) : PassManager(Name) {}
FunctionPassManager(StringRef Name, StringRef Pipeline,
CreatePassFunc CreatePass)
: PassManager(Name, Pipeline, CreatePass) {}
bool runOnFunction(Function &F) final;
};

class RegionPassManager final : public PassManager<RegionPass, RegionPass> {
public:
RegionPassManager(StringRef Name) : PassManager(Name) {}
RegionPassManager(StringRef Name, StringRef Pipeline,
CreatePassFunc CreatePass)
: PassManager(Name, Pipeline, CreatePass) {}
bool runOnRegion(Region &R) final;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_BOTTOMUPVEC_H

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/SandboxIR/Constant.h"
#include "llvm/SandboxIR/Pass.h"
#include "llvm/SandboxIR/PassManager.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"

namespace llvm::sandboxir {

class RegionPassManager;

class BottomUpVec final : public FunctionPass {
bool Change = false;
LegalityAnalysis Legality;
Expand All @@ -32,8 +32,12 @@ class BottomUpVec final : public FunctionPass {
RegionPassManager RPM;

public:
BottomUpVec();
BottomUpVec(StringRef Pipeline);
bool runOnFunction(Function &F) final;
void printPipeline(raw_ostream &OS) const final {
OS << getName() << "\n";
RPM.printPipeline(OS);
}
};

} // namespace llvm::sandboxir
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_PRINTINSTRUCTIONCOUNT_H
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_PRINTINSTRUCTIONCOUNT_H

#include "llvm/SandboxIR/Pass.h"
#include "llvm/SandboxIR/Region.h"

namespace llvm::sandboxir {

/// A Region pass that prints the instruction count for the region to stdout.
/// Used to test -sbvec-passes while we don't have any actual optimization
/// passes.
class PrintInstructionCount final : public RegionPass {
public:
PrintInstructionCount() : RegionPass("null") {}
bool runOnRegion(Region &R) final {
outs() << "InstructionCount: " << std::distance(R.begin(), R.end()) << "\n";
return false;
}
};

} // namespace llvm::sandboxir

#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_PRINTINSTRUCTIONCOUNTPASS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- RegionsFromMetadata.h ------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A SandboxIR function pass that builds regions from IR metadata and then runs
// a pipeline of region passes on them. This is useful to test region passes in
// isolation without relying on the output of the bottom-up vectorizer.
//

#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_REGIONSFROMMETADATA_H
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_REGIONSFROMMETADATA_H

#include "llvm/ADT/StringRef.h"
#include "llvm/SandboxIR/Pass.h"
#include "llvm/SandboxIR/PassManager.h"

namespace llvm::sandboxir {

class RegionsFromMetadata final : public FunctionPass {
// The PM containing the pipeline of region passes.
RegionPassManager RPM;

public:
RegionsFromMetadata(StringRef Pipeline);
bool runOnFunction(Function &F) final;
void printPipeline(raw_ostream &OS) const final {
OS << getName() << "\n";
RPM.printPipeline(OS);
}
};

} // namespace llvm::sandboxir

#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_REGIONSFROMMETADATA_H
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include <memory>

#include "llvm/IR/PassManager.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
#include "llvm/SandboxIR/PassManager.h"

namespace llvm {

Expand All @@ -20,8 +20,8 @@ class TargetTransformInfo;
class SandboxVectorizerPass : public PassInfoMixin<SandboxVectorizerPass> {
TargetTransformInfo *TTI = nullptr;

// The main vectorizer pass.
sandboxir::BottomUpVec BottomUpVecPass;
// A pipeline of SandboxIR function passes run by the vectorizer.
sandboxir::FunctionPassManager FPM;

bool runImpl(Function &F);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//===- SandboxVectorizerPassBuilder.h ---------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utility functions so passes with sub-pipelines can create SandboxVectorizer
// passes without replicating the same logic in each pass.
//
#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SANDBOXVECTORIZERPASSBUILDER_H
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SANDBOXVECTORIZERPASSBUILDER_H

#include "llvm/ADT/StringRef.h"
#include "llvm/SandboxIR/Pass.h"

#include <memory>

namespace llvm::sandboxir {

class SandboxVectorizerPassBuilder {
public:
static std::unique_ptr<FunctionPass> createFunctionPass(StringRef Name,
StringRef Args);
static std::unique_ptr<RegionPass> createRegionPass(StringRef Name,
StringRef Args);
};

} // namespace llvm::sandboxir

#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SANDBOXVECTORIZERPASSBUILDER_H
2 changes: 2 additions & 0 deletions llvm/lib/Transforms/Vectorize/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ add_llvm_component_library(LLVMVectorize
SandboxVectorizer/DependencyGraph.cpp
SandboxVectorizer/Interval.cpp
SandboxVectorizer/Passes/BottomUpVec.cpp
SandboxVectorizer/Passes/RegionsFromMetadata.cpp
SandboxVectorizer/SandboxVectorizer.cpp
SandboxVectorizer/SandboxVectorizerPassBuilder.cpp
SandboxVectorizer/SeedCollector.cpp
SLPVectorizer.cpp
Vectorize.cpp
Expand Down
Loading
Loading