Skip to content

Commit 93e6632

Browse files
[mlir][shape] Migrate bufferization to BufferizableOpInterface
Differential Revision: https://reviews.llvm.org/D121043
1 parent df6c26f commit 93e6632

File tree

7 files changed

+208
-98
lines changed

7 files changed

+208
-98
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace shape {
16+
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace shape
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_SHAPE_BUFFERIZABLEOPINTERFACEIMPL_H

mlir/include/mlir/Dialect/Shape/Transforms/Passes.h

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,6 @@ void populateShapeRewritePatterns(RewritePatternSet &patterns);
4040
void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns);
4141
std::unique_ptr<OperationPass<FuncOp>> createRemoveShapeConstraintsPass();
4242

43-
/// Populates patterns for shape dialect structural type conversions and sets up
44-
/// the provided ConversionTarget with the appropriate legality configuration
45-
/// for the ops to get converted properly.
46-
///
47-
/// A "structural" type conversion is one where the underlying ops are
48-
/// completely agnostic to the actual types involved and simply need to update
49-
/// their types consistently. An example of this is shape.assuming -- the
50-
/// shape.assuming op and the corresponding shape.assuming_yield op need to have
51-
/// consistent types, but the exact types don't matter. So all that we need to
52-
/// do for a structural type conversion is to update both of their types
53-
/// consistently to the new types prescribed by the TypeConverter.
54-
void populateShapeStructuralTypeConversionsAndLegality(
55-
TypeConverter &typeConverter, RewritePatternSet &patterns,
56-
ConversionTarget &target);
57-
5843
// Bufferizes shape dialect ops.
5944
//
6045
// Note that most shape dialect ops must be converted to std before
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
10+
11+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13+
#include "mlir/Dialect/Shape/IR/Shape.h"
14+
#include "mlir/IR/Dialect.h"
15+
#include "mlir/IR/Operation.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
18+
using namespace mlir;
19+
using namespace mlir::bufferization;
20+
using namespace mlir::shape;
21+
22+
namespace mlir {
23+
namespace shape {
24+
namespace {
25+
26+
/// Bufferization of shape.assuming.
27+
struct AssumingOpInterface
28+
: public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
29+
shape::AssumingOp> {
30+
SmallVector<OpOperand *>
31+
getAliasingOpOperand(Operation *op, OpResult opResult,
32+
const BufferizationState &state) const {
33+
// AssumingOps do not have tensor OpOperands. The yielded value can be any
34+
// SSA value that is in scope. To allow for use-def chain traversal through
35+
// AssumingOps in the analysis, the corresponding yield value is considered
36+
// to be aliasing with the result.
37+
auto assumingOp = cast<shape::AssumingOp>(op);
38+
size_t resultNum = std::distance(op->getOpResults().begin(),
39+
llvm::find(op->getOpResults(), opResult));
40+
// TODO: Support multiple blocks.
41+
assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
42+
"expected exactly 1 block");
43+
auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
44+
assumingOp.getDoRegion().front().getTerminator());
45+
assert(yieldOp && "expected shape.assuming_yield terminator");
46+
return {&yieldOp->getOpOperand(resultNum)};
47+
}
48+
49+
// TODO: For better bufferization results, this could return `true` only if
50+
// there is a memory write in the region.
51+
bool isMemoryWrite(Operation *op, OpResult opResult,
52+
const BufferizationState &state) const {
53+
// Similar to scf.if, results of this op are always considered memory writes
54+
// in the analysis. This is a useful pattern for all ops that have tensor
55+
// OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
56+
// implemented in terms of `bufferizesToMemoryWrite`, which does not work on
57+
// ops without OpOperands.
58+
return true;
59+
}
60+
61+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
62+
const BufferizationState &state) const {
63+
auto assumingOp = cast<shape::AssumingOp>(op);
64+
65+
// Compute new result types.
66+
SmallVector<Type> newResultTypes;
67+
for (Type type : assumingOp->getResultTypes()) {
68+
if (auto tensorType = type.dyn_cast<TensorType>()) {
69+
newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
70+
} else {
71+
newResultTypes.push_back(type);
72+
}
73+
}
74+
75+
// Create new op and move over region.
76+
auto newOp = rewriter.create<shape::AssumingOp>(
77+
op->getLoc(), newResultTypes, assumingOp.getWitness());
78+
newOp.getDoRegion().takeBody(assumingOp.getRegion());
79+
80+
// Update terminator.
81+
assert(newOp.getDoRegion().getBlocks().size() == 1 &&
82+
"only 1 block supported");
83+
Block *newBlock = &newOp.getDoRegion().front();
84+
auto yieldOp = cast<shape::AssumingYieldOp>(newBlock->getTerminator());
85+
rewriter.setInsertionPoint(yieldOp);
86+
SmallVector<Value> newYieldValues;
87+
for (const auto &it : llvm::enumerate(yieldOp.operands())) {
88+
Value val = it.value();
89+
if (val.getType().isa<TensorType>()) {
90+
newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
91+
yieldOp.getLoc(), newResultTypes[it.index()], val));
92+
} else {
93+
newYieldValues.push_back(val);
94+
}
95+
}
96+
rewriter.replaceOpWithNewOp<shape::AssumingYieldOp>(yieldOp,
97+
newYieldValues);
98+
99+
// Update all uses of the old op.
100+
rewriter.setInsertionPointAfter(newOp);
101+
SmallVector<Value> newResults;
102+
for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
103+
if (it.value().isa<TensorType>()) {
104+
newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
105+
assumingOp.getLoc(), newOp->getResult(it.index())));
106+
} else {
107+
newResults.push_back(newOp->getResult(it.index()));
108+
}
109+
}
110+
111+
// Replace old op.
112+
rewriter.replaceOp(assumingOp, newResults);
113+
114+
return success();
115+
}
116+
117+
BufferRelation bufferRelation(Operation *op, OpResult opResult,
118+
const BufferizationState &state) const {
119+
return BufferRelation::Equivalent;
120+
}
121+
};
122+
123+
/// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing
124+
/// ops, so this is for analysis only.
125+
struct AssumingYieldOpInterface
126+
: public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
127+
shape::AssumingOp> {
128+
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
129+
const BufferizationState &state) const {
130+
return true;
131+
}
132+
133+
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
134+
const BufferizationState &state) const {
135+
return false;
136+
}
137+
138+
SmallVector<OpResult>
139+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
140+
const BufferizationState &state) const {
141+
assert(isa<shape::AssumingOp>(op->getParentOp()) &&
142+
"expected that parent is an AssumingOp");
143+
return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
144+
}
145+
146+
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
147+
const BufferizationState &state) const {
148+
// Yield operands always bufferize inplace. Otherwise, an alloc + copy
149+
// may be generated inside the block. We should not return/yield allocations
150+
// when possible.
151+
return true;
152+
}
153+
154+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
155+
const BufferizationState &state) const {
156+
// Op is bufferized as part of AssumingOp.
157+
return failure();
158+
}
159+
};
160+
161+
} // namespace
162+
} // namespace shape
163+
} // namespace mlir
164+
165+
void mlir::shape::registerBufferizableOpInterfaceExternalModels(
166+
DialectRegistry &registry) {
167+
registry.addOpInterface<shape::AssumingOp, AssumingOpInterface>();
168+
registry.addOpInterface<shape::AssumingYieldOp, AssumingYieldOpInterface>();
169+
}

mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,32 @@
88

99
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
1010
#include "PassDetail.h"
11+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1112
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1213
#include "mlir/Dialect/MemRef/IR/MemRef.h"
14+
#include "mlir/Dialect/Shape/IR/Shape.h"
15+
#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
1316
#include "mlir/Dialect/Shape/Transforms/Passes.h"
1417
#include "mlir/Pass/Pass.h"
1518

1619
using namespace mlir;
20+
using namespace bufferization;
1721

1822
namespace {
1923
struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
2024
void runOnOperation() override {
21-
MLIRContext &ctx = getContext();
25+
BufferizationOptions options = getPartialBufferizationOptions();
26+
options.allowDialectInFilter<shape::ShapeDialect>();
2227

23-
RewritePatternSet patterns(&ctx);
24-
bufferization::BufferizeTypeConverter typeConverter;
25-
ConversionTarget target(ctx);
26-
27-
bufferization::populateBufferizeMaterializationLegality(target);
28-
populateShapeStructuralTypeConversionsAndLegality(typeConverter, patterns,
29-
target);
30-
31-
if (failed(applyPartialConversion(getOperation(), target,
32-
std::move(patterns))))
28+
if (failed(bufferizeOp(getOperation(), options)))
3329
signalPassFailure();
3430
}
31+
32+
void getDependentDialects(DialectRegistry &registry) const override {
33+
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
34+
shape::ShapeDialect>();
35+
shape::registerBufferizableOpInterfaceExternalModels(registry);
36+
}
3537
};
3638
} // namespace
3739

mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
add_mlir_dialect_library(MLIRShapeOpsTransforms
2+
BufferizableOpInterfaceImpl.cpp
23
Bufferize.cpp
34
RemoveShapeConstraints.cpp
45
ShapeToShapeLowering.cpp
5-
StructuralTypeConversions.cpp
66

77
ADDITIONAL_HEADER_DIRS
88
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRShapeOpsTransforms
1414
target_link_libraries(MLIRShapeOpsTransforms
1515
PUBLIC
1616
MLIRArithmetic
17+
MLIRBufferization
1718
MLIRBufferizationTransforms
1819
MLIRIR
1920
MLIRMemRef

mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp

Lines changed: 0 additions & 70 deletions
This file was deleted.

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2702,7 +2702,10 @@ cc_library(
27022702
"lib/Dialect/Shape/Transforms/*.cpp",
27032703
"lib/Dialect/Shape/Transforms/*.h",
27042704
]),
2705-
hdrs = ["include/mlir/Dialect/Shape/Transforms/Passes.h"],
2705+
hdrs = [
2706+
"include/mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h",
2707+
"include/mlir/Dialect/Shape/Transforms/Passes.h",
2708+
],
27062709
includes = ["include"],
27072710
deps = [
27082711
":ArithmeticDialect",

0 commit comments

Comments
 (0)