Skip to content

[MLIR][NFC] Retire let constructor for MemRef #134788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 1 addition & 33 deletions mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,46 +37,14 @@ class VectorDialect;
} // namespace vector

namespace memref {

//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//

#define GEN_PASS_DECL
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"

/// Creates an instance of the ExpandOps pass that legalizes memref dialect ops
/// to be convertible to LLVM. For example, `memref.reshape` gets converted to
/// `memref_reinterpret_cast`.
std::unique_ptr<Pass> createExpandOpsPass();

/// Creates an operation pass to fold memref aliasing ops into consumer
/// load/store ops into `patterns`.
std::unique_ptr<Pass> createFoldMemRefAliasOpsPass();

/// Creates an interprocedural pass to normalize memrefs to have a trivial
/// (identity) layout map.
std::unique_ptr<OperationPass<ModuleOp>> createNormalizeMemRefsPass();

/// Creates an operation pass to resolve `memref.dim` operations with values
/// that are defined by operations that implement the
/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input
/// operands.
std::unique_ptr<Pass> createResolveRankedShapeTypeResultDimsPass();

/// Creates an operation pass to resolve `memref.dim` operations with values
/// that are defined by operations that implement the
/// `InferShapedTypeOpInterface` or the `ReifyRankedShapedTypeOpInterface`,
/// in terms of shapes of its input operands.
std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();

/// Creates an operation pass to expand some memref operation into
/// easier to reason about operations.
std::unique_ptr<Pass> createExpandStridedMetadataPass();

/// Creates an operation pass to expand `memref.realloc` operations into their
/// components.
std::unique_ptr<Pass> createExpandReallocPass(bool emitDeallocs = true);

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
Expand Down
32 changes: 12 additions & 20 deletions mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,16 @@

include "mlir/Pass/PassBase.td"

def ExpandOps : Pass<"memref-expand"> {
def ExpandOpsPass : Pass<"memref-expand"> {
let summary = "Legalize memref operations to be convertible to LLVM.";
let constructor = "mlir::memref::createExpandOpsPass()";
}

def FoldMemRefAliasOps : Pass<"fold-memref-alias-ops"> {
def FoldMemRefAliasOpsPass : Pass<"fold-memref-alias-ops"> {
let summary = "Fold memref alias ops into consumer load/store ops";
let description = [{
The pass folds loading/storing from/to memref aliasing ops to loading/storing
from/to the original memref.
}];
let constructor = "mlir::memref::createFoldMemRefAliasOpsPass()";
let dependentDialects = [
"affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
];
Expand All @@ -44,9 +42,9 @@ def MemRefEmulateWideInt : Pass<"memref-emulate-wide-int"> {
let dependentDialects = ["vector::VectorDialect"];
}

def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
def NormalizeMemRefsPass : Pass<"normalize-memrefs", "ModuleOp"> {
let summary = "Normalize memrefs";
let description = [{
let description = [{
This pass transforms memref types with a non-trivial
[layout map](https://mlir.llvm.org/docs/Dialects/Builtin/#affine-map-layout)
into memref types with an identity layout map, e.g. (i, j) -> (i, j). This
Expand Down Expand Up @@ -155,40 +153,36 @@ def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
}
```
}];
let constructor = "mlir::memref::createNormalizeMemRefsPass()";
let dependentDialects = ["affine::AffineDialect"];
}

def ResolveRankedShapeTypeResultDims :
Pass<"resolve-ranked-shaped-type-result-dims"> {
def ResolveRankedShapeTypeResultDimsPass
: Pass<"resolve-ranked-shaped-type-result-dims"> {
let summary = "Resolve memref.dim of result values of ranked shape type";
let description = [{
The pass resolves memref.dim of result of operations that
implement the `ReifyRankedShapedTypeOpInterface` in terms of
shapes of its operands.
}];
let constructor =
"mlir::memref::createResolveRankedShapeTypeResultDimsPass()";
let dependentDialects = [
"memref::MemRefDialect", "tensor::TensorDialect"
];
}

def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
let summary = "Resolve memref.dim of result values";
let description = [{
The pass resolves memref.dim of result of operations that
implement the `InferShapedTypeOpInterface` or
`ReifyRankedShapedTypeOpInterface` in terms of shapes of its
operands.
}];
let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()";
let dependentDialects = [
"affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
];
}

def ExpandStridedMetadata : Pass<"expand-strided-metadata"> {
def ExpandStridedMetadataPass : Pass<"expand-strided-metadata"> {
let summary = "Expand memref operations into easier to analyze constructs";
let description = [{
The pass expands memref operations that modify the metadata of a memref
Expand All @@ -205,13 +199,12 @@ def ExpandStridedMetadata : Pass<"expand-strided-metadata"> {
- `memref.extract_strided_metadata`
- `memref.subview`
}];
let constructor = "mlir::memref::createExpandStridedMetadataPass()";
let dependentDialects = [
"affine::AffineDialect", "memref::MemRefDialect"
];
}

def ExpandRealloc : Pass<"expand-realloc"> {
def ExpandReallocPass : Pass<"expand-realloc"> {
let summary = "Expand memref.realloc operations into its components";
let description = [{
The `memref.realloc` operation performs a conditional allocation and copy to
Expand Down Expand Up @@ -243,11 +236,10 @@ def ExpandRealloc : Pass<"expand-realloc"> {
}
```
}];
let options = [
Option<"emitDeallocs", "emit-deallocs", "bool", /*default=*/"true",
"Emit deallocation operations for the original MemRef">,
let options = [Option<"emitDeallocs", "emit-deallocs", "bool",
/*default=*/"true",
"Emit deallocation operations for the original MemRef">,
];
let constructor = "mlir::memref::createExpandReallocPass()";
let dependentDialects = [
"arith::ArithDialect", "scf::SCFDialect", "memref::MemRefDialect"
];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

void mlir::bufferization::buildBufferDeallocationPipeline(
OpPassManager &pm, const BufferDeallocationPipelineOptions &options) {
pm.addPass(memref::createExpandReallocPass(/*emitDeallocs=*/false));
memref::ExpandReallocPassOptions expandAllocPassOptions{
/*emitDeallocs=*/false};
pm.addPass(memref::createExpandReallocPass(expandAllocPassOptions));
pm.addPass(createCanonicalizerPass());

OwnershipBasedBufferDeallocationPassOptions deallocationOptions{
Expand Down
8 changes: 2 additions & 6 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDOPS
#define GEN_PASS_DEF_EXPANDOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
Expand Down Expand Up @@ -130,7 +130,7 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
}
};

struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
void runOnOperation() override {
MLIRContext &ctx = getContext();

Expand Down Expand Up @@ -161,7 +161,3 @@ void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
patterns.getContext());
}

std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
return std::make_unique<ExpandOpsPass>();
}
14 changes: 4 additions & 10 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDREALLOC
#define GEN_PASS_DEF_EXPANDREALLOCPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
Expand Down Expand Up @@ -142,11 +142,9 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
};

struct ExpandReallocPass
: public memref::impl::ExpandReallocBase<ExpandReallocPass> {
ExpandReallocPass(bool emitDeallocs)
: memref::impl::ExpandReallocBase<ExpandReallocPass>() {
this->emitDeallocs.setValue(emitDeallocs);
}
: public memref::impl::ExpandReallocPassBase<ExpandReallocPass> {
using Base::Base;

void runOnOperation() override {
MLIRContext &ctx = getContext();

Expand All @@ -169,7 +167,3 @@ void mlir::memref::populateExpandReallocPatterns(RewritePatternSet &patterns,
bool emitDeallocs) {
patterns.add<ExpandReallocOpPattern>(patterns.getContext(), emitDeallocs);
}

std::unique_ptr<Pass> mlir::memref::createExpandReallocPass(bool emitDeallocs) {
return std::make_unique<ExpandReallocPass>(emitDeallocs);
}
8 changes: 2 additions & 6 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA
#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATAPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
Expand Down Expand Up @@ -1213,7 +1213,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
namespace {

struct ExpandStridedMetadataPass final
: public memref::impl::ExpandStridedMetadataBase<
: public memref::impl::ExpandStridedMetadataPassBase<
ExpandStridedMetadataPass> {
void runOnOperation() override;
};
Expand All @@ -1225,7 +1225,3 @@ void ExpandStridedMetadataPass::runOnOperation() {
memref::populateExpandStridedMetadataPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() {
return std::make_unique<ExpandStridedMetadataPass>();
}
8 changes: 2 additions & 6 deletions mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPS
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
Expand Down Expand Up @@ -848,7 +848,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
namespace {

struct FoldMemRefAliasOpsPass final
: public memref::impl::FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
: public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
void runOnOperation() override;
};

Expand All @@ -859,7 +859,3 @@ void FoldMemRefAliasOpsPass::runOnOperation() {
memref::populateFoldMemRefAliasOpPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
return std::make_unique<FoldMemRefAliasOpsPass>();
}
9 changes: 2 additions & 7 deletions mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_NORMALIZEMEMREFS
#define GEN_PASS_DEF_NORMALIZEMEMREFSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
Expand All @@ -40,7 +40,7 @@ namespace {
/// to call a non-normalizable function, we treat that function as
/// non-normalizable as well. We assume external functions to be normalizable.
struct NormalizeMemRefs
: public memref::impl::NormalizeMemRefsBase<NormalizeMemRefs> {
: public memref::impl::NormalizeMemRefsPassBase<NormalizeMemRefs> {
void runOnOperation() override;
void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp);
bool areMemRefsNormalizable(func::FuncOp funcOp);
Expand All @@ -53,11 +53,6 @@ struct NormalizeMemRefs

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::memref::createNormalizeMemRefsPass() {
return std::make_unique<NormalizeMemRefs>();
}

void NormalizeMemRefs::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
ModuleOp moduleOp = getOperation();
Expand Down
16 changes: 4 additions & 12 deletions mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMSPASS
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
Expand Down Expand Up @@ -164,13 +164,13 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {

namespace {
struct ResolveRankedShapeTypeResultDimsPass final
: public memref::impl::ResolveRankedShapeTypeResultDimsBase<
: public memref::impl::ResolveRankedShapeTypeResultDimsPassBase<
ResolveRankedShapeTypeResultDimsPass> {
void runOnOperation() override;
};

struct ResolveShapedTypeResultDimsPass final
: public memref::impl::ResolveShapedTypeResultDimsBase<
: public memref::impl::ResolveShapedTypeResultDimsPassBase<
ResolveShapedTypeResultDimsPass> {
void runOnOperation() override;
};
Expand Down Expand Up @@ -206,11 +206,3 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}

std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
return std::make_unique<ResolveShapedTypeResultDimsPass>();
}

std::unique_ptr<Pass> memref::createResolveRankedShapeTypeResultDimsPass() {
return std::make_unique<ResolveRankedShapeTypeResultDimsPass>();
}
2 changes: 1 addition & 1 deletion mlir/test/python/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def testInvalidNesting():
try:
pm = PassManager.parse("func.func(normalize-memrefs)")
except ValueError as e:
# CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
# CHECK: ValueError exception: Can't add pass 'NormalizeMemRefsPass' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
log("ValueError exception:", e)
else:
log("Exception not produced")
Expand Down
Loading