Skip to content

[mlir][linalg] NFC: Use tablegen macro for pass constructors #82892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 1 addition & 37 deletions mlir/include/mlir/Dialect/Linalg/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,43 +27,7 @@ struct OneShotBufferizationOptions;
} // namespace bufferization

#define GEN_PASS_DECL
#include "mlir/Dialect/Linalg/Passes.h.inc"

std::unique_ptr<Pass> createConvertElementwiseToLinalgPass();

std::unique_ptr<Pass> createLinalgFoldUnitExtentDimsPass();

std::unique_ptr<Pass> createLinalgElementwiseOpFusionPass();
std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();

std::unique_ptr<Pass> createLinalgNamedOpConversionPass();

std::unique_ptr<Pass> createLinalgInlineScalarOperandsPass();

/// Create a pass to convert Linalg operations to scf.for loops and
/// memref.load/memref.store accesses.
std::unique_ptr<Pass> createConvertLinalgToLoopsPass();

/// Create a pass to convert Linalg operations to scf.parallel loops and
/// memref.load/memref.store accesses.
std::unique_ptr<Pass> createConvertLinalgToParallelLoopsPass();

/// Create a pass to convert Linalg operations to affine.for loops and
/// affine_load/affine_store accesses.
/// Placeholder for now, this is NYI.
std::unique_ptr<Pass> createConvertLinalgToAffineLoopsPass();

/// Create a pass to convert Linalg operations which work on tensors to use
/// buffers instead.
std::unique_ptr<Pass> createLinalgBufferizePass();

/// Create a pass to convert named Linalg operations to Linalg generic
/// operations.
std::unique_ptr<Pass> createLinalgGeneralizationPass();

/// Create a pass to convert Linalg operations to equivalent operations that
/// work on primitive types, if possible.
std::unique_ptr<Pass> createLinalgDetensorizePass();
#include "mlir/Dialect/Linalg/Passes.h.inc" // IWYU pragma: keep

//===----------------------------------------------------------------------===//
// Registration
Expand Down
87 changes: 38 additions & 49 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

include "mlir/Pass/PassBase.td"

def ConvertElementwiseToLinalg : Pass<"convert-elementwise-to-linalg", ""> {
def ConvertElementwiseToLinalgPass : Pass<"convert-elementwise-to-linalg", ""> {
let summary = "Convert ElementwiseMappable ops to linalg";
let description = [{
Convert ops with the `ElementwiseMappable` trait to linalg parallel loops.
Expand All @@ -20,54 +20,17 @@ def ConvertElementwiseToLinalg : Pass<"convert-elementwise-to-linalg", ""> {
run on op which contains linalg ops (most commonly a
FunctionOpInterface op).
}];
let constructor = "mlir::createConvertElementwiseToLinalgPass()";
let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
}

def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
let summary = "Remove unit-extent dimension in Linalg ops on tensors";
let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";
let options = [
Option<"useRankReducingSlices", "use-rank-reducing-slices", "bool",
/*default=*/"false",
"Generate rank-reducing slices instead of reassociative reshapes">
];
let dependentDialects = [
"linalg::LinalgDialect", "affine::AffineDialect", "memref::MemRefDialect"
];
}

def LinalgElementwiseOpFusion : Pass<"linalg-fuse-elementwise-ops"> {
let summary = "Fuse elementwise operations on tensors";
let constructor = "mlir::createLinalgElementwiseOpFusionPass()";
let dependentDialects = [
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
];
}

def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
let summary = "Convert from one named linalg op to another.";
let constructor = "mlir::createLinalgNamedOpConversionPass()";
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
}

def LinalgInlineScalarOperands : Pass<"linalg-inline-scalar-operands"> {
let summary = "Inline scalar operands into linalg generic ops";
let constructor = "mlir::createLinalgInlineScalarOperandsPass()";
let dependentDialects = [
"linalg::LinalgDialect"
];
}

def LinalgLowerToAffineLoops : Pass<"convert-linalg-to-affine-loops"> {
def ConvertLinalgToAffineLoopsPass : Pass<"convert-linalg-to-affine-loops"> {
let summary = "Lower the operations from the linalg dialect into affine "
"loops";
let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
let dependentDialects = [
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"];
}

def LinalgLowerToLoops : Pass<"convert-linalg-to-loops"> {
def ConvertLinalgToLoopsPass : Pass<"convert-linalg-to-loops"> {
let summary = "Lower the operations from the linalg dialect into loops";
let description = [{
Lowers the `linalg` ops to loop nests using `scf.for`.
Expand All @@ -76,19 +39,17 @@ def LinalgLowerToLoops : Pass<"convert-linalg-to-loops"> {
i.e., tensor operands and results must be converted to memrefs via
bufferization.
}];
let constructor = "mlir::createConvertLinalgToLoopsPass()";
let dependentDialects = [
"linalg::LinalgDialect",
"scf::SCFDialect",
"affine::AffineDialect"
];
}

def LinalgLowerToParallelLoops
def ConvertLinalgToParallelLoopsPass
: Pass<"convert-linalg-to-parallel-loops"> {
let summary = "Lower the operations from the linalg dialect into parallel "
"loops";
let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
let dependentDialects = [
"affine::AffineDialect",
"linalg::LinalgDialect",
Expand All @@ -97,9 +58,39 @@ def LinalgLowerToParallelLoops
];
}

def LinalgBufferize : Pass<"linalg-bufferize"> {
def LinalgFoldUnitExtentDimsPass : Pass<"linalg-fold-unit-extent-dims", ""> {
let summary = "Remove unit-extent dimension in Linalg ops on tensors";
let options = [
Option<"useRankReducingSlices", "use-rank-reducing-slices", "bool",
/*default=*/"false",
"Generate rank-reducing slices instead of reassociative reshapes">
];
let dependentDialects = [
"linalg::LinalgDialect", "affine::AffineDialect", "memref::MemRefDialect"
];
}

def LinalgElementwiseOpFusionPass : Pass<"linalg-fuse-elementwise-ops"> {
let summary = "Fuse elementwise operations on tensors";
let dependentDialects = [
"affine::AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"
];
}

def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {
let summary = "Convert from one named linalg op to another.";
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
}

def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
let summary = "Inline scalar operands into linalg generic ops";
let dependentDialects = [
"linalg::LinalgDialect"
];
}

def LinalgBufferizePass : Pass<"linalg-bufferize"> {
let summary = "Bufferize the linalg dialect";
let constructor = "mlir::createLinalgBufferizePass()";
let dependentDialects = [
"affine::AffineDialect",
"bufferization::BufferizationDialect",
Expand All @@ -108,15 +99,13 @@ def LinalgBufferize : Pass<"linalg-bufferize"> {
];
}

def LinalgGeneralization : Pass<"linalg-generalize-named-ops"> {
def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
let summary = "Convert named ops into generic ops";
let constructor = "mlir::createLinalgGeneralizationPass()";
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgDetensorize : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let constructor = "mlir::createLinalgDetensorizePass()";
let dependentDialects = [];

let description = [{
Expand Down
10 changes: 4 additions & 6 deletions mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "mlir/Pass/Pass.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGBUFFERIZE
#define GEN_PASS_DEF_LINALGBUFFERIZEPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

Expand All @@ -32,7 +32,9 @@ namespace {
/// Converts Linalg operations that work on tensor-type operands or results to
/// work on buffers.
struct LinalgBufferizePass
: public impl::LinalgBufferizeBase<LinalgBufferizePass> {
: public impl::LinalgBufferizePassBase<LinalgBufferizePass> {
using impl::LinalgBufferizePassBase<
LinalgBufferizePass>::LinalgBufferizePassBase;
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
options.opFilter.allowDialect<linalg::LinalgDialect>();
Expand All @@ -48,7 +50,3 @@ struct LinalgBufferizePass
}
};
} // namespace

std::unique_ptr<Pass> mlir::createLinalgBufferizePass() {
return std::make_unique<LinalgBufferizePass>();
}
10 changes: 4 additions & 6 deletions mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGDETENSORIZE
#define GEN_PASS_DEF_LINALGDETENSORIZEPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

Expand Down Expand Up @@ -164,7 +164,9 @@ class DetensorizeTypeConverter : public TypeConverter {

/// @see LinalgDetensorize in Linalg/Passes.td for more details.
struct LinalgDetensorize
: public impl::LinalgDetensorizeBase<LinalgDetensorize> {
: public impl::LinalgDetensorizePassBase<LinalgDetensorize> {
using impl::LinalgDetensorizePassBase<
LinalgDetensorize>::LinalgDetensorizePassBase;
LinalgDetensorize() = default;

class CostModel {
Expand Down Expand Up @@ -576,7 +578,3 @@ struct LinalgDetensorize
}
};
} // namespace

std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
return std::make_unique<LinalgDetensorize>();
}
11 changes: 5 additions & 6 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

Expand Down Expand Up @@ -689,7 +689,10 @@ void mlir::linalg::populateMoveInitOperandsToInputPattern(
namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
: public impl::LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
: public impl::LinalgFoldUnitExtentDimsPassBase<
LinalgFoldUnitExtentDimsPass> {
using impl::LinalgFoldUnitExtentDimsPassBase<
LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *context = op->getContext();
Expand All @@ -705,7 +708,3 @@ struct LinalgFoldUnitExtentDimsPass
}
};
} // namespace

std::unique_ptr<Pass> mlir::createLinalgFoldUnitExtentDimsPass() {
return std::make_unique<LinalgFoldUnitExtentDimsPass>();
}
11 changes: 4 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSION
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

Expand Down Expand Up @@ -1927,8 +1926,10 @@ namespace {
// favor of test passes that check the functionality of each of the patterns
// added here individually.
struct LinalgElementwiseOpFusionPass
: public impl::LinalgElementwiseOpFusionBase<
: public impl::LinalgElementwiseOpFusionPassBase<
LinalgElementwiseOpFusionPass> {
using impl::LinalgElementwiseOpFusionPassBase<
LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase;
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *context = op->getContext();
Expand Down Expand Up @@ -1963,7 +1964,3 @@ struct LinalgElementwiseOpFusionPass
};

} // namespace

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
return std::make_unique<LinalgElementwiseOpFusionPass>();
}
10 changes: 4 additions & 6 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALG
#define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

Expand Down Expand Up @@ -121,8 +121,10 @@ void mlir::linalg::populateElementwiseToLinalgConversionPatterns(

namespace {
class ConvertElementwiseToLinalgPass
: public impl::ConvertElementwiseToLinalgBase<
: public impl::ConvertElementwiseToLinalgPassBase<
ConvertElementwiseToLinalgPass> {
using impl::ConvertElementwiseToLinalgPassBase<
ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase;

void runOnOperation() final {
auto *func = getOperation();
Expand All @@ -140,7 +142,3 @@ class ConvertElementwiseToLinalgPass
}
};
} // namespace

std::unique_ptr<Pass> mlir::createConvertElementwiseToLinalgPass() {
return std::make_unique<ConvertElementwiseToLinalgPass>();
}
15 changes: 7 additions & 8 deletions mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGGENERALIZATION
#define GEN_PASS_DEF_LINALGGENERALIZENAMEDOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

Expand Down Expand Up @@ -76,14 +76,17 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,

namespace {

struct LinalgGeneralizationPass
: public impl::LinalgGeneralizationBase<LinalgGeneralizationPass> {
struct LinalgGeneralizeNamedOpsPass
: public impl::LinalgGeneralizeNamedOpsPassBase<
LinalgGeneralizeNamedOpsPass> {
using impl::LinalgGeneralizeNamedOpsPassBase<
LinalgGeneralizeNamedOpsPass>::LinalgGeneralizeNamedOpsPassBase;
void runOnOperation() override;
};

} // namespace

void LinalgGeneralizationPass::runOnOperation() {
void LinalgGeneralizeNamedOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLinalgNamedOpsGeneralizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
Expand All @@ -93,7 +96,3 @@ void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
RewritePatternSet &patterns) {
patterns.add<LinalgGeneralizationPattern>(patterns.getContext());
}

std::unique_ptr<Pass> mlir::createLinalgGeneralizationPass() {
return std::make_unique<LinalgGeneralizationPass>();
}
Loading