Skip to content

Commit b4a7fa3

Browse files
committed
Add BufferizationState as argument to bufferize method
1 parent d1d837c commit b4a7fa3

25 files changed

+175
-86
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
426426
/*retType=*/"::llvm::LogicalResult",
427427
/*methodName=*/"bufferize",
428428
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
429-
"const ::mlir::bufferization::BufferizationOptions &":$options),
429+
"const ::mlir::bufferization::BufferizationOptions &":$options,
430+
"::mlir::bufferization::BufferizationState &":$state),
430431
/*methodBody=*/"",
431432
/*defaultImplementation=*/[{
432433
llvm_unreachable("bufferize not implemented");

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
9393

9494
let extraClassDeclaration = [{
9595
LogicalResult bufferize(RewriterBase &rewriter,
96-
const BufferizationOptions &options);
96+
const BufferizationOptions &options,
97+
BufferizationState &state);
9798

9899
bool resultBufferizesToMemoryWrite(OpResult opResult,
99100
const AnalysisState &state);
@@ -282,7 +283,8 @@ def Bufferization_MaterializeInDestinationOp
282283

283284
let extraClassDeclaration = [{
284285
LogicalResult bufferize(RewriterBase &rewriter,
285-
const BufferizationOptions &options);
286+
const BufferizationOptions &options,
287+
BufferizationState &state);
286288

287289
bool bufferizesToMemoryRead(OpOperand &opOperand,
288290
const AnalysisState &state);
@@ -375,7 +377,8 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
375377
}
376378

377379
LogicalResult bufferize(RewriterBase &rewriter,
378-
const BufferizationOptions &options);
380+
const BufferizationOptions &options,
381+
BufferizationState &state);
379382
}];
380383
}
381384

@@ -458,7 +461,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
458461
//===------------------------------------------------------------------===//
459462

460463
LogicalResult bufferize(RewriterBase &rewriter,
461-
const BufferizationOptions &options) const {
464+
const BufferizationOptions &options,
465+
BufferizationState &state) const {
462466
// to_tensor/to_memref pairs fold away after bufferization.
463467
return success();
464468
}
@@ -550,7 +554,8 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
550554
}
551555

552556
LogicalResult bufferize(RewriterBase &rewriter,
553-
const BufferizationOptions &options);
557+
const BufferizationOptions &options,
558+
BufferizationState &state);
554559
}];
555560

556561
let assemblyFormat = [{

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ class BufferPlacementTransformationBase {
122122
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
123123
// names. Duplicates are avoided.
124124
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
125+
SymbolTableCollection &symbolTables,
125126
uint64_t alignment,
126127
Attribute memorySpace = {});
127128

mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct BufferizationStatistics {
4545
/// additional buffer copies or set "options.copyBeforeWrite = true". The
4646
/// general bufferization entry point is `runOneShotBufferize`.
4747
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
48+
BufferizationState &bufferizationState,
4849
BufferizationStatistics *statistics = nullptr);
4950

5051
/// Bufferize the signature of `block` and its callers (i.e., ops that have the

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
270270
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
271271
LogicalResult
272272
runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
273+
BufferizationState &state,
273274
BufferizationStatistics *statistics = nullptr);
274275

275276
} // namespace bufferization

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace bufferization {
2020
struct BufferizationStatistics;
2121
class OneShotAnalysisState;
2222
struct OneShotBufferizationOptions;
23+
class BufferizationState;
2324

2425
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
2526
/// `state`.
@@ -38,6 +39,7 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
3839
/// will be inserted only to these FuncOps.
3940
llvm::LogicalResult
4041
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
42+
BufferizationState &state,
4143
BufferizationStatistics *statistics = nullptr);
4244

4345
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -50,7 +52,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
5052
llvm::LogicalResult runOneShotModuleBufferize(
5153
ModuleOp moduleOp,
5254
const bufferization::OneShotBufferizationOptions &options,
53-
BufferizationStatistics *statistics = nullptr);
55+
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
5456

5557
} // namespace bufferization
5658
} // namespace mlir

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ namespace mlir {
3030
namespace bufferization {
3131
class AllocTensorOp;
3232
class OneShotAnalysisState;
33+
class BufferizationState;
3334
} // namespace bufferization
3435

3536
namespace linalg {

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ struct ConstantOpInterface
2424
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
2525
arith::ConstantOp> {
2626
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
27-
const BufferizationOptions &options) const {
27+
const BufferizationOptions &options,
28+
BufferizationState &state) const {
2829
auto constantOp = cast<arith::ConstantOp>(op);
2930
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
3031

@@ -46,7 +47,8 @@ struct ConstantOpInterface
4647
// Create global memory segment and replace tensor with memref pointing to
4748
// that memory segment.
4849
FailureOr<memref::GlobalOp> globalOp =
49-
getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
50+
getGlobalFor(constantOp, state.getSymbolTables(),
51+
options.bufferAlignment, memorySpace);
5052
if (failed(globalOp))
5153
return failure();
5254
memref::GlobalOp globalMemref = *globalOp;
@@ -83,7 +85,8 @@ struct IndexCastOpInterface
8385
}
8486

8587
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
86-
const BufferizationOptions &options) const {
88+
const BufferizationOptions &options,
89+
BufferizationState &state) const {
8790
auto castOp = cast<arith::IndexCastOp>(op);
8891
auto resultTensorType = cast<TensorType>(castOp.getType());
8992

@@ -131,7 +134,8 @@ struct SelectOpInterface
131134
}
132135

133136
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
134-
const BufferizationOptions &options) const {
137+
const BufferizationOptions &options,
138+
BufferizationState &state) const {
135139
auto selectOp = cast<arith::SelectOp>(op);
136140
Location loc = selectOp.getLoc();
137141

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ void mlir::bufferization::populateDynamicDimSizes(
149149
//===----------------------------------------------------------------------===//
150150

151151
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
152-
const BufferizationOptions &options) {
152+
const BufferizationOptions &options,
153+
BufferizationState &state) {
153154
OpBuilder::InsertionGuard g(rewriter);
154155
Location loc = getLoc();
155156

@@ -529,7 +530,8 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
529530
//===----------------------------------------------------------------------===//
530531

531532
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
532-
const BufferizationOptions &options) {
533+
const BufferizationOptions &options,
534+
BufferizationState &state) {
533535
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
534536
if (failed(buffer))
535537
return failure();
@@ -576,7 +578,8 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
576578

577579
LogicalResult
578580
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
579-
const BufferizationOptions &options) {
581+
const BufferizationOptions &options,
582+
BufferizationState &state) {
580583
bool tensorDest = isa<TensorType>(getDest().getType());
581584
Value buffer;
582585
if (tensorDest) {
@@ -861,7 +864,8 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
861864
}
862865

863866
LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
864-
const BufferizationOptions &options) {
867+
const BufferizationOptions &options,
868+
BufferizationState &state) {
865869
// Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
866870
(void)foldToMemrefToTensorPair(rewriter, *this, options);
867871
// Note: The return value of `bufferize` indicates whether there was an error

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,21 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
8383
}
8484

8585
auto payloadOps = state.getPayloadOps(getTarget());
86+
BufferizationState bufferizationState;
87+
8688
for (Operation *target : payloadOps) {
8789
if (!isa<ModuleOp, FunctionOpInterface>(target))
8890
return emitSilenceableError() << "expected module or function target";
8991
auto moduleOp = dyn_cast<ModuleOp>(target);
9092
if (options.bufferizeFunctionBoundaries) {
9193
if (!moduleOp)
9294
return emitSilenceableError() << "expected module target";
93-
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
95+
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
96+
bufferizationState)))
9497
return emitSilenceableError() << "bufferization failed";
9598
} else {
96-
if (failed(bufferization::runOneShotBufferize(target, options)))
99+
if (failed(bufferization::runOneShotBufferize(target, options,
100+
bufferizationState)))
97101
return emitSilenceableError() << "bufferization failed";
98102
}
99103
}

mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,9 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
103103
//===----------------------------------------------------------------------===//
104104

105105
FailureOr<memref::GlobalOp>
106-
bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
107-
Attribute memorySpace) {
106+
bufferization::getGlobalFor(arith::ConstantOp constantOp,
107+
SymbolTableCollection &symbolTables,
108+
uint64_t alignment, Attribute memorySpace) {
108109
auto type = cast<RankedTensorType>(constantOp.getType());
109110
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
110111
if (!moduleOp)
@@ -127,7 +128,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
127128
// Create a builder without an insertion point. We will insert using the
128129
// symbol table to guarantee unique names.
129130
OpBuilder globalBuilder(moduleOp.getContext());
130-
SymbolTable symbolTable(moduleOp);
131+
SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
131132

132133
// Create a pretty name.
133134
SmallString<64> buf;

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,12 @@ struct OneShotBufferizePass
161161
return signalPassFailure();
162162
}
163163

164+
BufferizationState state;
164165
BufferizationStatistics statistics;
165166
ModuleOp moduleOp = getOperation();
166167
if (opt.bufferizeFunctionBoundaries) {
167-
if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
168+
if (failed(
169+
runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
168170
signalPassFailure();
169171
return;
170172
}
@@ -175,7 +177,7 @@ struct OneShotBufferizePass
175177
"'bufferize-function-boundaries'");
176178
return signalPassFailure();
177179
}
178-
if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
180+
if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
179181
signalPassFailure();
180182
return;
181183
}
@@ -275,6 +277,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
275277

276278
LogicalResult bufferization::bufferizeOp(Operation *op,
277279
const BufferizationOptions &options,
280+
BufferizationState &bufferizationState,
278281
BufferizationStatistics *statistics) {
279282
if (options.copyBeforeWrite) {
280283
AnalysisState state(options);
@@ -331,7 +334,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
331334
<< "//===-------------------------------------------===//\n"
332335
<< "IR after bufferizing: " << nextOp->getName() << "\n");
333336
rewriter.setInsertionPoint(nextOp);
334-
if (failed(bufferizableOp.bufferize(rewriter, options))) {
337+
if (failed(
338+
bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
335339
LLVM_DEBUG(llvm::dbgs()
336340
<< "failed to bufferize\n"
337341
<< "//===-------------------------------------------===//\n");

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ struct CallOpInterface
219219
/// All function arguments are writable. It is the responsibility of the
220220
/// CallOp to insert buffer copies where necessary.
221221
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
222-
const BufferizationOptions &options) const {
222+
const BufferizationOptions &options,
223+
BufferizationState &state) const {
223224
func::CallOp callOp = cast<func::CallOp>(op);
224225

225226
// 1. Compute the result types of the new CallOp.
@@ -325,7 +326,8 @@ struct ReturnOpInterface
325326
}
326327

327328
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
328-
const BufferizationOptions &options) const {
329+
const BufferizationOptions &options,
330+
BufferizationState &state) const {
329331
#ifndef NDEBUG
330332
auto returnOp = cast<func::ReturnOp>(op);
331333
assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -394,7 +396,8 @@ struct FuncOpInterface
394396
/// All function bbArgs are writable unless they are explicitly marked as
395397
/// read-only. Callers must insert copies when needed.
396398
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
397-
const BufferizationOptions &options) const {
399+
const BufferizationOptions &options,
400+
BufferizationState &state) const {
398401
auto funcOp = cast<FuncOp>(op);
399402
FunctionType funcType = funcOp.getFunctionType();
400403

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,10 +1365,9 @@ LogicalResult bufferization::analyzeOp(Operation *op,
13651365
return success(!failedAnalysis);
13661366
}
13671367

1368-
LogicalResult
1369-
bufferization::runOneShotBufferize(Operation *op,
1370-
const OneShotBufferizationOptions &options,
1371-
BufferizationStatistics *statistics) {
1368+
LogicalResult bufferization::runOneShotBufferize(
1369+
Operation *op, const OneShotBufferizationOptions &options,
1370+
BufferizationState &state, BufferizationStatistics *statistics) {
13721371
// copy-before-write deactivates the analysis. It cannot be used together with
13731372
// test-analysis-only.
13741373
assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
@@ -1391,5 +1390,5 @@ bufferization::runOneShotBufferize(Operation *op,
13911390

13921391
// Bufferize the op and its nested ops. If options.copyBeforeWrite is set,
13931392
// a new buffer copy is allocated every time a buffer is written to.
1394-
return bufferizeOp(op, options, statistics);
1393+
return bufferizeOp(op, options, state, statistics);
13951394
}

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ void mlir::bufferization::removeBufferizationAttributesInModule(
506506

507507
LogicalResult mlir::bufferization::bufferizeModuleOp(
508508
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
509-
BufferizationStatistics *statistics) {
509+
BufferizationState &state, BufferizationStatistics *statistics) {
510510
assert(options.bufferizeFunctionBoundaries &&
511511
"expected that function boundary bufferization is activated");
512512
IRRewriter rewriter(moduleOp.getContext());
@@ -542,10 +542,10 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
542542
// Buffer copies must be inserted before every write.
543543
OneShotBufferizationOptions updatedOptions = options;
544544
updatedOptions.copyBeforeWrite = true;
545-
if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
545+
if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics)))
546546
return failure();
547547
} else {
548-
if (failed(bufferizeOp(funcOp, options, statistics)))
548+
if (failed(bufferizeOp(funcOp, options, state, statistics)))
549549
return failure();
550550
}
551551

@@ -559,7 +559,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
559559
// Functions were already bufferized.
560560
if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
561561
continue;
562-
if (failed(bufferizeOp(&op, options, statistics)))
562+
if (failed(bufferizeOp(&op, options, state, statistics)))
563563
return failure();
564564
}
565565

@@ -571,7 +571,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
571571

572572
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
573573
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
574-
BufferizationStatistics *statistics) {
574+
BufferizationState &state, BufferizationStatistics *statistics) {
575575
assert(options.bufferizeFunctionBoundaries &&
576576
"expected that function boundary bufferization is activated");
577577
assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
@@ -600,7 +600,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
600600
}
601601
if (options.testAnalysisOnly)
602602
return success();
603-
if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
603+
if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
604604
return failure();
605605
return success();
606606
}

mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ struct BranchLikeOpInterface
4343
}
4444

4545
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
46-
const BufferizationOptions &options) const {
46+
const BufferizationOptions &options,
47+
BufferizationState &state) const {
4748
// The operands of this op are bufferized together with the block signature.
4849
return success();
4950
}

0 commit comments

Comments
 (0)