Skip to content

Commit 63cb6af

Browse files
authored
[MLIR] Add bufferization state to getBufferType and resolveConflicts interface methods (#141466)
The PR continues the work started in #141019 by adding the `BufferizationState` class also to the `getBufferType` and `resolveConflicts` interface methods, together with the additional support functions that are used throughout the bufferization infrastructure.
1 parent b577438 commit 63cb6af

22 files changed

+260
-182
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,13 +598,14 @@ class BufferizationState {
598598
FailureOr<Value>
599599
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
600600
const BufferizationOptions &options,
601-
bool copy = true);
601+
const BufferizationState &state, bool copy = true);
602602

603603
/// Lookup the buffer for the given value. If the value was not bufferized
604604
/// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
605605
/// from which the memref operand is returned.
606606
FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
607-
const BufferizationOptions &options);
607+
const BufferizationOptions &options,
608+
const BufferizationState &state);
608609

609610
/// Return the buffer type for a given Value (tensor) after bufferization
610611
/// without bufferizing any IR.
@@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
615616
///
616617
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
617618
FailureOr<BaseMemRefType> getBufferType(Value value,
618-
const BufferizationOptions &options);
619+
const BufferizationOptions &options,
620+
const BufferizationState &state);
619621

620622
/// Return the buffer type for a given Value (tensor) after bufferization
621623
/// without bufferizing any IR. This function (and not the other overload
@@ -629,6 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
629631
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
630632
FailureOr<BaseMemRefType> getBufferType(Value value,
631633
const BufferizationOptions &options,
634+
const BufferizationState &state,
632635
SmallVector<Value> &invocationStack);
633636

634637
/// Return "true" if the given op has tensor semantics and should be bufferized.
@@ -709,6 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
709712
/// places.
710713
FailureOr<BaseMemRefType>
711714
defaultGetBufferType(Value value, const BufferizationOptions &options,
715+
const BufferizationState &state,
712716
SmallVector<Value> &invocationStack);
713717

714718
/// This is the default implementation of

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
381381
/*retType=*/"::llvm::LogicalResult",
382382
/*methodName=*/"resolveConflicts",
383383
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
384-
"const ::mlir::bufferization::AnalysisState &":$state),
384+
"const ::mlir::bufferization::AnalysisState &":$analysisState,
385+
"const ::mlir::bufferization::BufferizationState &":$bufferizationState),
385386
/*methodBody=*/"",
386387
/*defaultImplementation=*/[{
387388
auto bufferizableOp =
388389
::llvm::cast<BufferizableOpInterface>($_op.getOperation());
389390
return bufferizableOp.resolveTensorOpOperandConflicts(
390-
rewriter, state);
391+
rewriter, analysisState, bufferizationState);
391392
}]
392393
>,
393394
InterfaceMethod<
@@ -528,6 +529,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
528529
/*methodName=*/"getBufferType",
529530
/*args=*/(ins "::mlir::Value":$value,
530531
"const ::mlir::bufferization::BufferizationOptions &":$options,
532+
"const ::mlir::bufferization::BufferizationState &":$state,
531533
"::llvm::SmallVector<::mlir::Value> &":$invocationStack),
532534
/*methodBody=*/"",
533535
/*defaultImplementation=*/[{
@@ -536,7 +538,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
536538
assert(invocationStack.back() == value &&
537539
"inconsistant invocation stack");
538540
return ::mlir::bufferization::detail::defaultGetBufferType(
539-
value, options, invocationStack);
541+
value, options, state, invocationStack);
540542
}]
541543
>,
542544
InterfaceMethod<
@@ -621,7 +623,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
621623
/// form of `bufferization.alloc_tensor` ops.
622624
::llvm::LogicalResult resolveTensorOpOperandConflicts(
623625
::mlir::RewriterBase &rewriter,
624-
const ::mlir::bufferization::AnalysisState &state);
626+
const ::mlir::bufferization::AnalysisState &analysisState,
627+
const ::mlir::bufferization::BufferizationState &bufferizationState);
625628

626629
/// Return `true` if the given OpOperand creates an alias but does neither
627630
/// read nor write. This implies that `bufferizesToMemoryRead` and

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
112112

113113
FailureOr<BaseMemRefType> getBufferType(
114114
Value value, const BufferizationOptions &options,
115+
const BufferizationState &state,
115116
SmallVector<Value> &invocationStack);
116117

117118
RankedTensorType getType() {
@@ -471,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
471472

472473
FailureOr<BaseMemRefType> getBufferType(
473474
Value value, const BufferizationOptions &options,
474-
SmallVector<Value> &invocationStack) {
475+
const BufferizationState &state, SmallVector<Value> &invocationStack) {
475476
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
476477
}
477478
}];

mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
3434

3535
FailureOr<BaseMemRefType>
3636
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
37+
const BufferizationState &state,
3738
SmallVector<Value> &invocationStack) const {
3839
// Note: The user may want to override this function for OpResults in
3940
// case the bufferized result type is different from the bufferized type of
4041
// the aliasing OpOperand (if any).
4142
if (isa<OpResult>(value))
42-
return bufferization::detail::defaultGetBufferType(value, options,
43+
return bufferization::detail::defaultGetBufferType(value, options, state,
4344
invocationStack);
4445

4546
// Compute the buffer type of the block argument by computing the bufferized
@@ -65,7 +66,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
6566
callerType = memrefType;
6667
} else {
6768
FailureOr<BaseMemRefType> maybeCallerType =
68-
bufferization::getBufferType(opOperand->get(), options,
69+
bufferization::getBufferType(opOperand->get(), options, state,
6970
invocationStack);
7071
if (failed(maybeCallerType))
7172
return failure();
@@ -81,9 +82,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
8182
if (bufferType == callerType)
8283
continue;
8384

84-
// If the computed buffer type does not match the computed buffer type
85-
// of the earlier forwarded operands, fall back to a buffer type with a
86-
// fully dynamic layout map.
85+
// If the computed buffer type does not match the computed buffer type
86+
// of the earlier forwarded operands, fall back to a buffer type with a
87+
// fully dynamic layout map.
8788
#ifndef NDEBUG
8889
if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
8990
assert(bufferType.hasRank() && callerType.hasRank() &&

mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
6262
/// `BufferizableOpInterface`. The buffer types of tensor block arguments are
6363
/// computed with `BufferizableOpIntercace::getBufferType`.
6464
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
65-
const BufferizationOptions &options);
65+
const BufferizationOptions &options,
66+
BufferizationState &state);
6667

6768
} // namespace bufferization
6869
} // namespace mlir

mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,15 @@ void hoistBuffersFromLoops(Operation *op);
7575
/// additional buffer allocations.
7676
LogicalResult insertTensorCopies(Operation *op,
7777
const OneShotBufferizationOptions &options,
78+
const BufferizationState &bufferizationState,
7879
BufferizationStatistics *statistics = nullptr);
7980

8081
/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
8182
/// After applying this transform, the IR can be bufferized without inserting
8283
/// additional buffer allocations.
83-
LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
84+
LogicalResult insertTensorCopies(Operation *op,
85+
const AnalysisState &analysisState,
86+
const BufferizationState &bufferizationState);
8487

8588
/// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
8689
/// ops.

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ struct IndexCastOpInterface
9090
auto castOp = cast<arith::IndexCastOp>(op);
9191
auto resultTensorType = cast<TensorType>(castOp.getType());
9292

93-
FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
93+
FailureOr<Value> source =
94+
getBuffer(rewriter, castOp.getIn(), options, state);
9495
if (failed(source))
9596
return failure();
9697
auto sourceType = cast<BaseMemRefType>(source->getType());
@@ -151,9 +152,9 @@ struct SelectOpInterface
151152
// the moment (one for each tensor). When copying the op result, only one
152153
// copy would be needed.
153154
FailureOr<Value> maybeTrueBuffer =
154-
getBuffer(rewriter, selectOp.getTrueValue(), options);
155+
getBuffer(rewriter, selectOp.getTrueValue(), options, state);
155156
FailureOr<Value> maybeFalseBuffer =
156-
getBuffer(rewriter, selectOp.getFalseValue(), options);
157+
getBuffer(rewriter, selectOp.getFalseValue(), options, state);
157158
if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
158159
return failure();
159160
Value trueBuffer = *maybeTrueBuffer;
@@ -164,7 +165,7 @@ struct SelectOpInterface
164165
// both of them to the most dynamic MemRef type.
165166
if (trueBuffer.getType() != falseBuffer.getType()) {
166167
auto targetType =
167-
bufferization::getBufferType(selectOp.getResult(), options);
168+
bufferization::getBufferType(selectOp.getResult(), options, state);
168169
if (failed(targetType))
169170
return failure();
170171
if (trueBuffer.getType() != *targetType)
@@ -182,13 +183,14 @@ struct SelectOpInterface
182183

183184
FailureOr<BaseMemRefType>
184185
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
186+
const BufferizationState &state,
185187
SmallVector<Value> &invocationStack) const {
186188
auto selectOp = cast<arith::SelectOp>(op);
187189
assert(value == selectOp.getResult() && "invalid value");
188-
auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
189-
options, invocationStack);
190-
auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
191-
options, invocationStack);
190+
auto trueType = bufferization::getBufferType(
191+
selectOp.getTrueValue(), options, state, invocationStack);
192+
auto falseType = bufferization::getBufferType(
193+
selectOp.getFalseValue(), options, state, invocationStack);
192194
if (failed(trueType) || failed(falseType))
193195
return failure();
194196
if (*trueType == *falseType)

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,8 @@ Operation *bufferization::getOwnerOfValue(Value value) {
165165
/// allocated.
166166
FailureOr<Value> bufferization::allocateTensorForShapedValue(
167167
OpBuilder &b, Location loc, Value shapedValue,
168-
const BufferizationOptions &options, bool copy) {
168+
const BufferizationOptions &options, const BufferizationState &state,
169+
bool copy) {
169170
Value tensor;
170171
if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
171172
tensor = shapedValue;
@@ -210,7 +211,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
210211
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
211212
if (copy)
212213
return allocTensorOp.getResult();
213-
FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
214+
FailureOr<BaseMemRefType> copyBufferType =
215+
getBufferType(tensor, options, state);
214216
if (failed(copyBufferType))
215217
return failure();
216218
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -222,7 +224,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
222224
}
223225

224226
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
225-
RewriterBase &rewriter, const AnalysisState &state) {
227+
RewriterBase &rewriter, const AnalysisState &analysisState,
228+
const BufferizationState &bufferizationState) {
226229
OpBuilder::InsertionGuard g(rewriter);
227230
Operation *op = getOperation();
228231
SmallVector<OpOperand *> outOfPlaceOpOperands;
@@ -235,16 +238,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
235238
Type operandType = opOperand.get().getType();
236239
if (!llvm::isa<TensorType>(operandType))
237240
continue;
238-
if (state.isInPlace(opOperand))
241+
if (analysisState.isInPlace(opOperand))
239242
continue;
240243
if (llvm::isa<UnrankedTensorType>(operandType))
241244
return op->emitError("copying of unranked tensors is not implemented");
242245

243-
AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
246+
AliasingValueList aliasingValues =
247+
analysisState.getAliasingValues(opOperand);
244248
if (aliasingValues.getNumAliases() == 1 &&
245249
isa<OpResult>(aliasingValues.getAliases()[0].value) &&
246-
!state.bufferizesToMemoryWrite(opOperand) &&
247-
state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
250+
!analysisState.bufferizesToMemoryWrite(opOperand) &&
251+
analysisState
252+
.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
248253
.getNumAliases() == 1 &&
249254
!isa<UnrankedTensorType>(
250255
aliasingValues.getAliases()[0].value.getType())) {
@@ -256,12 +261,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
256261
// cannot be copied at the moment).
257262
Value value = aliasingValues.getAliases()[0].value;
258263
outOfPlaceValues.push_back(value);
259-
if (!state.canOmitTensorCopy(opOperand))
264+
if (!analysisState.canOmitTensorCopy(opOperand))
260265
copiedOpValues.insert(value);
261266
} else {
262267
// In all other cases, make a copy of the OpOperand.
263268
outOfPlaceOpOperands.push_back(&opOperand);
264-
if (!state.canOmitTensorCopy(opOperand))
269+
if (!analysisState.canOmitTensorCopy(opOperand))
265270
copiedOpOperands.insert(&opOperand);
266271
}
267272
}
@@ -270,8 +275,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
270275
rewriter.setInsertionPoint(op);
271276
for (OpOperand *opOperand : outOfPlaceOpOperands) {
272277
FailureOr<Value> copy = allocateTensorForShapedValue(
273-
rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
274-
copiedOpOperands.contains(opOperand));
278+
rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
279+
bufferizationState, copiedOpOperands.contains(opOperand));
275280
if (failed(copy))
276281
return failure();
277282
rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
@@ -281,8 +286,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
281286
rewriter.setInsertionPointAfter(op);
282287
for (Value value : outOfPlaceValues) {
283288
FailureOr<Value> copy = allocateTensorForShapedValue(
284-
rewriter, op->getLoc(), value, state.getOptions(),
285-
copiedOpValues.count(value));
289+
rewriter, op->getLoc(), value, analysisState.getOptions(),
290+
bufferizationState, copiedOpValues.count(value));
286291
if (failed(copy))
287292
return failure();
288293
SmallVector<OpOperand *> uses = llvm::to_vector(
@@ -665,7 +670,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
665670
}
666671

667672
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
668-
const BufferizationOptions &options) {
673+
const BufferizationOptions &options,
674+
const BufferizationState &state) {
669675
#ifndef NDEBUG
670676
auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
671677
assert(tensorType && "unexpected non-tensor type");
@@ -678,7 +684,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
678684
// Insert to_buffer op.
679685
OpBuilder::InsertionGuard g(rewriter);
680686
setInsertionPointAfter(rewriter, value);
681-
FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
687+
FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
682688
if (failed(memrefType))
683689
return failure();
684690
ensureToBufferOpIsValid(value, *memrefType);
@@ -689,14 +695,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
689695

690696
/// Return the buffer type for a given Value (tensor) after bufferization.
691697
FailureOr<BaseMemRefType>
692-
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
698+
bufferization::getBufferType(Value value, const BufferizationOptions &options,
699+
const BufferizationState &state) {
693700
SmallVector<Value> invocationStack;
694-
return getBufferType(value, options, invocationStack);
701+
return getBufferType(value, options, state, invocationStack);
695702
}
696703

697704
/// Return the buffer type for a given Value (tensor) after bufferization.
698705
FailureOr<BaseMemRefType>
699706
bufferization::getBufferType(Value value, const BufferizationOptions &options,
707+
const BufferizationState &state,
700708
SmallVector<Value> &invocationStack) {
701709
assert(llvm::isa<TensorType>(value.getType()) &&
702710
"unexpected non-tensor type");
@@ -708,7 +716,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
708716
Operation *op = getOwnerOfValue(value);
709717
auto bufferizableOp = options.dynCastBufferizableOp(op);
710718
if (bufferizableOp)
711-
return bufferizableOp.getBufferType(value, options, invocationStack);
719+
return bufferizableOp.getBufferType(value, options, state, invocationStack);
712720

713721
// Op is not bufferizable.
714722
auto memSpace =
@@ -944,6 +952,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
944952

945953
FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
946954
Value value, const BufferizationOptions &options,
955+
const BufferizationState &bufferizationState,
947956
SmallVector<Value> &invocationStack) {
948957
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
949958

@@ -954,14 +963,15 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
954963
// Value is an OpResult.
955964
Operation *op = getOwnerOfValue(value);
956965
auto opResult = llvm::cast<OpResult>(value);
957-
AnalysisState state(options);
958-
AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
966+
AnalysisState analysisState(options);
967+
AliasingOpOperandList aliases = analysisState.getAliasingOpOperands(opResult);
959968
if (aliases.getNumAliases() > 0 &&
960969
aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
961970
// If the OpResult has an equivalent OpOperand, both OpResult and
962971
// OpOperand bufferize to the exact same buffer type.
963972
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
964-
return getBufferType(equivalentOperand, options, invocationStack);
973+
return getBufferType(equivalentOperand, options, bufferizationState,
974+
invocationStack);
965975
}
966976

967977
// If we do not know the memory space and there is no default memory space,

0 commit comments

Comments
 (0)