Skip to content

Commit ee070d0

Browse files
[mlir][bufferization] Support custom types (1/N) (#142986)
Following the addition of TensorLike and BufferLike type interfaces (see 00eaff3), introduce minimal changes required to bufferize a custom tensor operation into a custom buffer operation. To achieve this, new interface methods are added to TensorLike type interface that abstract away the differences between existing (tensor -> memref) and custom conversions. The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first understand the basics and reach consensus design-wise. --- Notable changes: * mlir::bufferization::getBufferType() returns BufferLikeType (instead of BaseMemRefType) * ToTensorOp / ToBufferOp operate on TensorLikeType / BufferLikeType. Operation argument "memref" renamed to "buffer" * ToTensorOp's tensor type inferring builder is dropped (users now need to provide the tensor type explicitly)
1 parent 40d2f39 commit ee070d0

27 files changed

+389
-135
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <optional>
1818

1919
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
20+
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
2021

2122
namespace mlir {
2223
class OpBuilder;
@@ -615,7 +616,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
615616
/// IR, this function can be used.
616617
///
617618
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
618-
FailureOr<BaseMemRefType> getBufferType(Value value,
619+
FailureOr<BufferLikeType> getBufferType(Value value,
619620
const BufferizationOptions &options,
620621
const BufferizationState &state);
621622

@@ -629,7 +630,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
629630
/// IR, this function can be used.
630631
///
631632
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
632-
FailureOr<BaseMemRefType> getBufferType(Value value,
633+
FailureOr<BufferLikeType> getBufferType(Value value,
633634
const BufferizationOptions &options,
634635
const BufferizationState &state,
635636
SmallVector<Value> &invocationStack);
@@ -739,6 +740,19 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
739740
/// This is the default implementation of
740741
/// BufferizableOpInterface::hasTensorSemantics
741742
bool defaultHasTensorSemantics(Operation *op);
743+
744+
/// This is a helper function used when buffer type is guaranteed to be memref.
745+
/// It performs two actions: failure state checking and an explicit llvm::cast<>
746+
/// from the buffer-like type interface to a BaseMemRefType. This allows easier
747+
/// management of differences in C++ types at the API boundaries. Valid buffer
748+
/// type is casted to the memref type. Otherwise, the failure state is
749+
/// propagated i.e. asMemRefType(mlir::failure()) returns mlir::failure().
750+
FailureOr<BaseMemRefType> asMemRefType(FailureOr<BufferLikeType> bufferType);
751+
752+
/// This function is a free-standing helper that relies on
753+
/// bufferization::TensorLikeTypeInterface to verify the types in tensor and
754+
/// buffer worlds match.
755+
bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
742756
} // namespace detail
743757

744758
} // namespace bufferization

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
1313
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
1414
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
15+
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
1516
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
1617
include "mlir/Interfaces/DestinationStyleOpInterface.td"
1718
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -386,20 +387,31 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
386387
// ToTensorOp
387388
//===----------------------------------------------------------------------===//
388389

390+
class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
391+
"specified tensor and buffer types match",
392+
CPred<
393+
"::mlir::bufferization::detail::typesMatchAfterBufferization("
394+
"$_op, $" # tensor # ", $" # buffer #")"
395+
>
396+
>;
397+
389398
def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
390399
BufferizableOpInterface,
391400
SameOperandsAndResultShape,
392401
SameOperandsAndResultElementType,
393-
AllElementTypesMatch<["memref", "result"]>
402+
Bufferization_TensorAndBufferMatch<"result", "buffer">
394403
]> {
395-
let summary = "create a tensor from a `memref`";
404+
let summary = "create a buffer-like type from a tensor-like type";
396405
let description = [{
397-
An operation that creates a tensor from a `memref`. The result value is a
398-
tensor whose shape and element type match the memref operand.
406+
An operation that creates a tensor from a buffer. The result value is a
407+
tensor-like type that must match the corresponding buffer-like operand as
408+
per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType
409+
and BaseMemRefType), this means that shapes and element types match between
410+
the tensor and the buffer.
399411

400412
The opposite of this op is `to_buffer`. Together, these two ops are
401413
useful for source/target materializations when doing type conversions
402-
involving tensors and memrefs.
414+
involving tensors and buffers.
403415

404416
Example:
405417

@@ -441,19 +453,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
441453
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
442454
}];
443455

444-
let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
456+
let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
445457
"the reference to load from",
446-
[MemReadAt<0, FullEffect>]>:$memref,
458+
[MemReadAt<0, FullEffect>]>:$buffer,
447459
UnitAttr:$restrict, UnitAttr:$writable);
448-
let results = (outs AnyTensor:$result);
460+
let results = (outs Bufferization_TensorLikeTypeInterface:$result);
449461

450462
let extraClassDeclaration = [{
451463
/// The result of a to_tensor is always a tensor.
452-
TensorType getType() {
453-
Type resultType = getResult().getType();
454-
if (::llvm::isa<TensorType>(resultType))
455-
return ::llvm::cast<TensorType>(resultType);
456-
return {};
464+
::mlir::bufferization::TensorLikeType getType() {
465+
return getResult().getType();
457466
}
458467

459468
//===------------------------------------------------------------------===//
@@ -472,22 +481,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
472481
FailureOr<BaseMemRefType> getBufferType(
473482
Value value, const BufferizationOptions &options,
474483
const BufferizationState &state, SmallVector<Value> &invocationStack) {
475-
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
484+
return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
476485
}
477486
}];
478487

479488
let assemblyFormat = [{
480-
$memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
481-
`:` type($memref) `to` type($result)
489+
$buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
490+
`:` type($buffer) `to` type($result)
482491
}];
483492

484-
let builders = [
485-
OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
486-
auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
487-
build($_builder, $_state, rtt, memref, restrict, writeable);
488-
}]>
489-
];
490-
491493
let hasCanonicalizer = 1;
492494
let hasFolder = 1;
493495
}
@@ -502,10 +504,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
502504
SameOperandsAndResultShape,
503505
SameOperandsAndResultElementType,
504506
Pure,
505-
AllShapesMatch<["memref", "tensor"]>,
506-
AllElementTypesMatch<["memref", "tensor"]>
507+
Bufferization_TensorAndBufferMatch<"tensor", "buffer">
507508
]> {
508-
let summary = "cast a tensor to memref";
509+
let summary = "cast a tensor-like type to buffer-like type";
509510
let description = [{
510511
An operation that returns the future buffer of a `tensor`.
511512

@@ -523,8 +524,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
523524
the returned buffer) will not be written to.
524525
}];
525526

526-
let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
527-
let results = (outs AnyRankedOrUnrankedMemRef:$memref);
527+
let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
528+
let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);
528529

529530
let extraClassDeclaration = [{
530531
//===------------------------------------------------------------------===//
@@ -559,7 +560,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
559560
}];
560561

561562
let assemblyFormat = [{
562-
$tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
563+
$tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
563564
}];
564565

565566
let hasFolder = 1;

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,15 @@
1313
// Bufferization Type Interfaces
1414
//===----------------------------------------------------------------------===//
1515

16+
#include "mlir/IR/Diagnostics.h"
1617
#include "mlir/IR/Types.h"
1718

19+
namespace mlir::bufferization {
20+
struct BufferizationOptions;
21+
class BufferizationState;
22+
class BufferLikeType;
23+
} // namespace mlir::bufferization
24+
1825
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
1926

2027
#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,30 @@ def Bufferization_TensorLikeTypeInterface
2121
let description = [{
2222
Indicates that this type is a tensor type (similarly to a MLIR builtin
2323
tensor) for bufferization purposes.
24-
25-
The interface currently has no methods as it is used by types to opt into
26-
being supported by the bufferization procedures.
2724
}];
25+
26+
let methods = [
27+
InterfaceMethod<[{
28+
Returns a BufferLike type for this TensorLike type.
29+
}],
30+
/*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
31+
/*methodName=*/"getBufferType",
32+
/*args=*/(ins
33+
"const ::mlir::bufferization::BufferizationOptions &":$options,
34+
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
35+
)
36+
>,
37+
InterfaceMethod<[{
38+
Returns whether a BufferLike type is compatible to this TensorLike type.
39+
The BufferLike type is assumed to be created by getBufferType().
40+
}],
41+
/*retTy=*/"::mlir::LogicalResult",
42+
/*methodName=*/"verifyCompatibleBufferType",
43+
/*args=*/(ins
44+
"::mlir::bufferization::BufferLikeType":$bufferType,
45+
"::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
46+
>
47+
];
2848
}
2949

3050
def Bufferization_BufferLikeTypeInterface

mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
6565
// The operand was already bufferized. Take its type directly.
6666
callerType = memrefType;
6767
} else {
68-
FailureOr<BaseMemRefType> maybeCallerType =
68+
FailureOr<BufferLikeType> maybeCallerType =
6969
bufferization::getBufferType(opOperand->get(), options, state,
7070
invocationStack);
7171
if (failed(maybeCallerType))
7272
return failure();
73-
callerType = *maybeCallerType;
73+
assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
74+
callerType = cast<BaseMemRefType>(*maybeCallerType);
7475
}
7576

7677
if (!bufferType) {

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ struct SelectOpInterface
164164
// buffers have different types, they differ only in their layout map. Cast
165165
// both of them to the most dynamic MemRef type.
166166
if (trueBuffer.getType() != falseBuffer.getType()) {
167-
auto targetType =
168-
bufferization::getBufferType(selectOp.getResult(), options, state);
167+
auto targetType = bufferization::detail::asMemRefType(
168+
bufferization::getBufferType(selectOp.getResult(), options, state));
169169
if (failed(targetType))
170170
return failure();
171171
if (trueBuffer.getType() != *targetType)
@@ -187,10 +187,12 @@ struct SelectOpInterface
187187
SmallVector<Value> &invocationStack) const {
188188
auto selectOp = cast<arith::SelectOp>(op);
189189
assert(value == selectOp.getResult() && "invalid value");
190-
auto trueType = bufferization::getBufferType(
191-
selectOp.getTrueValue(), options, state, invocationStack);
192-
auto falseType = bufferization::getBufferType(
193-
selectOp.getFalseValue(), options, state, invocationStack);
190+
auto trueType =
191+
bufferization::detail::asMemRefType(bufferization::getBufferType(
192+
selectOp.getTrueValue(), options, state, invocationStack));
193+
auto falseType =
194+
bufferization::detail::asMemRefType(bufferization::getBufferType(
195+
selectOp.getFalseValue(), options, state, invocationStack));
194196
if (failed(trueType) || failed(falseType))
195197
return failure();
196198
if (*trueType == *falseType)

0 commit comments

Comments
 (0)