Skip to content

Commit 39848d0

Browse files
authored
Revert "[mlir] Remove dialect specific bufferization passes" (#93528)
Reverts #93488 Buildbot failure: https://lab.llvm.org/buildbot/#/builders/220/builds/39911
1 parent cbed9a6 commit 39848d0

File tree

32 files changed

+426
-10
lines changed

32 files changed

+426
-10
lines changed

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ namespace arith {
2424
class WideIntEmulationConverter;
2525
class NarrowTypeEmulationConverter;
2626

27+
/// Create a pass to bufferize arith.constant ops.
28+
std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
29+
2730
/// Adds patterns to emulate wide Arith and Function ops over integer
2831
/// types into supported ones. This is done by splitting original power-of-two
2932
/// i2N integer types into two iN halves.

mlir/include/mlir/Dialect/Arith/Transforms/Passes.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14+
def ArithBufferizePass : Pass<"arith-bufferize", "ModuleOp"> {
15+
let summary = "Bufferize Arith dialect ops.";
16+
let description = [{
17+
This pass bufferizes arith dialect ops.
18+
19+
This pass needs to be a module pass because it inserts memref.global
20+
ops into the module, which cannot be done safely from a function pass due to
21+
multi-threading. Most other bufferization passes can run in parallel at
22+
function granularity.
23+
}];
24+
let options = [
25+
Option<"alignment", "alignment", "unsigned", /*default=*/"0",
26+
"Create global memrefs with a specified alignment">,
27+
];
28+
}
29+
1430
def ArithExpandOpsPass : Pass<"arith-expand"> {
1531
let summary = "Legalize Arith ops to be convertible to LLVM.";
1632
let dependentDialects = ["vector::VectorDialect"];

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
221221
/// insert_slice ops.
222222
std::unique_ptr<Pass> createEmptyTensorEliminationPass();
223223

224+
/// Create a pass that bufferizes ops from the bufferization dialect.
225+
std::unique_ptr<Pass> createBufferizationBufferizePass();
226+
224227
//===----------------------------------------------------------------------===//
225228
// Registration
226229
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,11 @@ def FinalizingBufferize : Pass<"finalizing-bufferize", "func::FuncOp"> {
350350
let constructor = "mlir::bufferization::createFinalizingBufferizePass()";
351351
}
352352

353+
def BufferizationBufferize : Pass<"bufferization-bufferize", "func::FuncOp"> {
354+
let summary = "Bufferize the `bufferization` dialect";
355+
let constructor = "mlir::bufferization::createBufferizationBufferizePass()";
356+
}
357+
353358
def DropEquivalentBufferResults : Pass<"drop-equivalent-buffer-results", "ModuleOp"> {
354359
let summary = "Remove MemRef return values that are equivalent to a bbArg";
355360
let description = [{

mlir/include/mlir/Dialect/Linalg/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ namespace func {
2222
class FuncOp;
2323
} // namespace func
2424

25+
namespace bufferization {
26+
struct OneShotBufferizationOptions;
27+
} // namespace bufferization
28+
2529
#define GEN_PASS_DECL
2630
#include "mlir/Dialect/Linalg/Passes.h.inc" // IWYU pragma: keep
2731

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
8989
];
9090
}
9191

92+
def LinalgBufferizePass : Pass<"linalg-bufferize"> {
93+
let summary = "Bufferize the linalg dialect";
94+
let dependentDialects = [
95+
"affine::AffineDialect",
96+
"bufferization::BufferizationDialect",
97+
"linalg::LinalgDialect",
98+
"memref::MemRefDialect",
99+
];
100+
}
101+
92102
def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
93103
let summary = "Convert named ops into generic ops";
94104
let dependentDialects = ["linalg::LinalgDialect"];

mlir/include/mlir/Dialect/Shape/Transforms/Passes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ void populateShapeRewritePatterns(RewritePatternSet &patterns);
4747
void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns);
4848
std::unique_ptr<OperationPass<func::FuncOp>> createRemoveShapeConstraintsPass();
4949

50+
// Bufferizes shape dialect ops.
51+
//
52+
// Note that most shape dialect ops must be converted to std before
53+
// bufferization happens, as they are intended to be bufferized at the std
54+
// level.
55+
std::unique_ptr<OperationPass<func::FuncOp>> createShapeBufferizePass();
56+
5057
/// Outline the shape computation part by adding shape.func and populate
5158
/// conrresponding mapping infomation into ShapeMappingAnalysis.
5259
std::unique_ptr<OperationPass<ModuleOp>> createOutlineShapeComputationPass();

mlir/include/mlir/Dialect/Shape/Transforms/Passes.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,11 @@ def ShapeToShapeLowering : Pass<"shape-to-shape-lowering", "func::FuncOp"> {
103103
let constructor = "mlir::createShapeToShapeLowering()";
104104
}
105105

106+
// TODO: Generalize this to allow any type conversions desired.
107+
def ShapeBufferize : Pass<"shape-bufferize", "func::FuncOp"> {
108+
let summary = "Bufferize the shape dialect.";
109+
let constructor = "mlir::createShapeBufferizePass()";
110+
let dependentDialects = ["bufferization::BufferizationDialect",
111+
"memref::MemRefDialect"];
112+
}
106113
#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES

mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ namespace tensor {
2121
/// Creates an instance of the `tensor` subset folding pass.
2222
std::unique_ptr<Pass> createFoldTensorSubsetOpsPass();
2323

24+
/// Creates an instance of the `tensor` dialect bufferization pass.
25+
std::unique_ptr<Pass> createTensorBufferizePass();
26+
2427
//===----------------------------------------------------------------------===//
2528
// Registration
2629
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tensor/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,9 @@ def FoldTensorSubsetOps : Pass<"fold-tensor-subset-ops"> {
2727
];
2828
}
2929

30+
def TensorBufferize : Pass<"tensor-bufferize", "func::FuncOp"> {
31+
let summary = "Bufferize the `tensor` dialect";
32+
let constructor = "mlir::tensor::createTensorBufferizePass()";
33+
}
34+
3035
#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES

mlir/include/mlir/Dialect/Vector/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ namespace vector {
1717
#define GEN_PASS_DECL
1818
#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
1919

20+
/// Creates an instance of the `vector` dialect bufferization pass.
21+
std::unique_ptr<Pass> createVectorBufferizePass();
22+
2023
/// Creates an instance of the `vector.mask` lowering pass.
2124
std::unique_ptr<Pass> createLowerVectorMaskPass();
2225

mlir/include/mlir/Dialect/Vector/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14+
def VectorBufferize : Pass<"vector-bufferize", "func::FuncOp"> {
15+
let summary = "Bufferize Vector dialect ops";
16+
let constructor = "mlir::vector::createVectorBufferizePass()";
17+
}
18+
1419
def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
1520
let summary = "Lower 'vector.mask' operations";
1621
let constructor = "mlir::vector::createLowerVectorMaskPass()";
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//===- Bufferize.cpp - Bufferization for Arith ops ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Arith/Transforms/Passes.h"
10+
11+
#include "mlir/Dialect/Arith/IR/Arith.h"
12+
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
13+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
14+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
15+
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
16+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
17+
18+
namespace mlir {
19+
namespace arith {
20+
#define GEN_PASS_DEF_ARITHBUFFERIZEPASS
21+
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
22+
} // namespace arith
23+
} // namespace mlir
24+
25+
using namespace mlir;
26+
using namespace bufferization;
27+
28+
namespace {
29+
/// Pass to bufferize Arith ops.
30+
struct ArithBufferizePass
31+
: public arith::impl::ArithBufferizePassBase<ArithBufferizePass> {
32+
using ArithBufferizePassBase::ArithBufferizePassBase;
33+
34+
ArithBufferizePass(uint64_t alignment = 0, bool constantOpOnly = false)
35+
: constantOpOnly(constantOpOnly) {
36+
this->alignment = alignment;
37+
}
38+
39+
void runOnOperation() override {
40+
BufferizationOptions options = getPartialBufferizationOptions();
41+
if (constantOpOnly) {
42+
options.opFilter.allowOperation<arith::ConstantOp>();
43+
} else {
44+
options.opFilter.allowDialect<arith::ArithDialect>();
45+
}
46+
options.bufferAlignment = alignment;
47+
48+
if (failed(bufferizeOp(getOperation(), options)))
49+
signalPassFailure();
50+
}
51+
52+
void getDependentDialects(DialectRegistry &registry) const override {
53+
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
54+
arith::ArithDialect>();
55+
arith::registerBufferizableOpInterfaceExternalModels(registry);
56+
}
57+
58+
private:
59+
bool constantOpOnly;
60+
};
61+
} // namespace
62+
63+
std::unique_ptr<Pass>
64+
mlir::arith::createConstantBufferizePass(uint64_t alignment) {
65+
return std::make_unique<ArithBufferizePass>(alignment,
66+
/*constantOpOnly=*/true);
67+
}

mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRArithTransforms
22
BufferDeallocationOpInterfaceImpl.cpp
33
BufferizableOpInterfaceImpl.cpp
4+
Bufferize.cpp
45
BufferViewFlowOpInterfaceImpl.cpp
56
EmulateUnsupportedFloats.cpp
67
EmulateWideInt.cpp

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,29 @@ struct OneShotBufferizePass
320320
};
321321
} // namespace
322322

323+
namespace {
324+
struct BufferizationBufferizePass
325+
: public bufferization::impl::BufferizationBufferizeBase<
326+
BufferizationBufferizePass> {
327+
void runOnOperation() override {
328+
BufferizationOptions options = getPartialBufferizationOptions();
329+
options.opFilter.allowDialect<BufferizationDialect>();
330+
331+
if (failed(bufferizeOp(getOperation(), options)))
332+
signalPassFailure();
333+
}
334+
335+
void getDependentDialects(DialectRegistry &registry) const override {
336+
registry
337+
.insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
338+
}
339+
};
340+
} // namespace
341+
342+
std::unique_ptr<Pass> mlir::bufferization::createBufferizationBufferizePass() {
343+
return std::make_unique<BufferizationBufferizePass>();
344+
}
345+
323346
std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
324347
return std::make_unique<OneShotBufferizePass>();
325348
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
//===- Bufferize.cpp - Bufferization of linalg ops ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Linalg/Passes.h"
10+
11+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
12+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13+
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14+
#include "mlir/Dialect/Func/IR/FuncOps.h"
15+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
16+
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
17+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
19+
#include "mlir/IR/BuiltinDialect.h"
20+
#include "mlir/IR/Operation.h"
21+
#include "mlir/Pass/Pass.h"
22+
23+
namespace mlir {
24+
#define GEN_PASS_DEF_LINALGBUFFERIZEPASS
25+
#include "mlir/Dialect/Linalg/Passes.h.inc"
26+
} // namespace mlir
27+
28+
using namespace mlir;
29+
using namespace bufferization;
30+
31+
namespace {
32+
/// Converts Linalg operations that work on tensor-type operands or results to
33+
/// work on buffers.
34+
struct LinalgBufferizePass
35+
: public impl::LinalgBufferizePassBase<LinalgBufferizePass> {
36+
using impl::LinalgBufferizePassBase<
37+
LinalgBufferizePass>::LinalgBufferizePassBase;
38+
void runOnOperation() override {
39+
BufferizationOptions options = getPartialBufferizationOptions();
40+
options.opFilter.allowDialect<linalg::LinalgDialect>();
41+
42+
if (failed(bufferizeOp(getOperation(), options)))
43+
signalPassFailure();
44+
}
45+
46+
void getDependentDialects(DialectRegistry &registry) const override {
47+
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
48+
tensor::TensorDialect, linalg::LinalgDialect>();
49+
linalg::registerBufferizableOpInterfaceExternalModels(registry);
50+
}
51+
};
52+
} // namespace

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
22
AllInterfaces.cpp
33
BubbleUpExtractSlice.cpp
44
BufferizableOpInterfaceImpl.cpp
5+
Bufferize.cpp
56
ConstantFold.cpp
67
ConvertToDestinationStyle.cpp
78
ConvertConv2DToImg2Col.cpp
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//====----- Bufferize.cpp - Bufferization of shape ops ---------*- C++-*--===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Shape/Transforms/Passes.h"
10+
11+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13+
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
14+
#include "mlir/Dialect/Func/IR/FuncOps.h"
15+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
16+
#include "mlir/Dialect/Shape/IR/Shape.h"
17+
#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
18+
#include "mlir/Pass/Pass.h"
19+
20+
namespace mlir {
21+
#define GEN_PASS_DEF_SHAPEBUFFERIZE
22+
#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
23+
} // namespace mlir
24+
25+
using namespace mlir;
26+
using namespace bufferization;
27+
28+
namespace {
29+
struct ShapeBufferizePass
30+
: public impl::ShapeBufferizeBase<ShapeBufferizePass> {
31+
void runOnOperation() override {
32+
BufferizationOptions options = getPartialBufferizationOptions();
33+
options.opFilter.allowDialect<shape::ShapeDialect>();
34+
35+
if (failed(bufferizeOp(getOperation(), options)))
36+
signalPassFailure();
37+
}
38+
39+
void getDependentDialects(DialectRegistry &registry) const override {
40+
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
41+
shape::ShapeDialect>();
42+
shape::registerBufferizableOpInterfaceExternalModels(registry);
43+
}
44+
};
45+
} // namespace
46+
47+
std::unique_ptr<OperationPass<func::FuncOp>> mlir::createShapeBufferizePass() {
48+
return std::make_unique<ShapeBufferizePass>();
49+
}

mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_dialect_library(MLIRShapeOpsTransforms
22
BufferizableOpInterfaceImpl.cpp
3+
Bufferize.cpp
34
OutlineShapeComputation.cpp
45
RemoveShapeConstraints.cpp
56
ShapeToShapeLowering.cpp

0 commit comments

Comments
 (0)