[mlir][bufferization] Support custom types (1/N) #142986

Merged: 5 commits, Jun 18, 2025
@@ -17,6 +17,7 @@
#include <optional>

#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"

namespace mlir {
class OpBuilder;
@@ -615,7 +616,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
const BufferizationState &state);

@@ -629,7 +630,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
@@ -739,6 +740,19 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
/// This is the default implementation of
/// BufferizableOpInterface::hasTensorSemantics
bool defaultHasTensorSemantics(Operation *op);

+/// This is a helper function used when the buffer type is guaranteed to be a
+/// memref. It performs two actions: checking the failure state and an explicit
+/// llvm::cast<> from the buffer-like type interface to BaseMemRefType. This
+/// allows easier management of the C++ type differences at API boundaries. A
+/// valid buffer type is cast to the memref type; otherwise, the failure state
+/// is propagated, i.e. asMemRefType(mlir::failure()) returns mlir::failure().
+FailureOr<BaseMemRefType> asMemRefType(FailureOr<BufferLikeType> bufferType);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::TensorLikeTypeInterface to verify that the types in the
+/// tensor and buffer worlds match.
+bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
} // namespace detail

} // namespace bufferization
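The failure-propagation contract of `asMemRefType` above can be sketched with self-contained stand-ins; this is a minimal sketch, assuming a toy two-level type hierarchy and `std::optional` in place of `mlir::FailureOr` (none of these classes are the real MLIR types):

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins, not the MLIR classes: a two-level buffer type
// hierarchy, and FailureOr modeled as std::optional where std::nullopt
// plays the role of mlir::failure().
struct BufferLikeType {
  virtual ~BufferLikeType() = default;
};
struct BaseMemRefType : BufferLikeType {
  int rank = 0;
};

template <typename T>
using FailureOr = std::optional<T>;

// Mirrors the documented contract of detail::asMemRefType(): a failure state
// passes through unchanged; a valid buffer type is cast down to the memref
// type, which the caller guarantees to hold.
FailureOr<BaseMemRefType *> asMemRefType(FailureOr<BufferLikeType *> buffer) {
  if (!buffer)
    return std::nullopt; // asMemRefType(failure()) returns failure()
  auto *memref = dynamic_cast<BaseMemRefType *>(*buffer);
  assert(memref && "buffer type was guaranteed to be a memref");
  return memref; // the real helper uses llvm::cast<BaseMemRefType>
}
```

This shape is why call sites can wrap `getBufferType` directly: the failure check and the cast are folded into one call.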
59 changes: 30 additions & 29 deletions mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -12,6 +12,7 @@
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -386,20 +387,31 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
// ToTensorOp
//===----------------------------------------------------------------------===//

+class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
+  "specified tensor and buffer types match",
+  CPred<
+    "::mlir::bufferization::detail::typesMatchAfterBufferization("
Contributor Author:

@matthias-springer this would be the other problematic place. With the current approach, we want to validate that bufferization is "valid" on a tensor <-> buffer level.

The current logic checks `tensor.getShape() == buffer.getShape()` and `tensor.getElementType() == buffer.getElementType()`; practically, this means that TensorLike and BufferLike are ShapedType (fine by me), but even that is not enough. We've recently started to experiment with shape bounds and dynamic shapes, so getShape checking might not be sufficient.

Instead, I think we should either restore the old comparison logic (which was changed in ced2fc7) or - more likely - have this put into an interface so that it's a customization point.

But then, which interface? TensorLike? BufferLike? Since it's a type matching function, it's kind of valid to be in both.

Member:

How about some kind of double dispatch? E.g., for to_tensor:

  • Op verifier checks that operand type implements BufferLikeTypeInterface.
  • Op verifier checks that result type implements TensorLikeTypeInterface.
  • Op verifier calls bufferLikeType.verifyCompatibleTensorType(tensorLikeType). The result is std::optional<LogicalResult>. success means that the type is compatible, failure means that it is incompatible, nullopt means that we don't know.
  • In case of nullopt, call tensorLikeType.verifyCompatibleBufferType(bufferLikeType).
  • If the result is still nullopt, we fail verification because neither of the two types "know" each other.
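This querying scheme can be sketched with plain C++ stand-ins; a minimal sketch under stated assumptions (the names `verifyToTensor` and the `std::optional<bool>` return, standing in for `std::optional<LogicalResult>`, are illustrative, not the MLIR API):

```cpp
#include <cassert>
#include <optional>

struct TensorLikeType;

// Buffer-side interface: may recognize a tensor type (true/false) or report
// "don't know" via std::nullopt.
struct BufferLikeType {
  virtual ~BufferLikeType() = default;
  virtual std::optional<bool>
  verifyCompatibleTensorType(const TensorLikeType &) const {
    return std::nullopt;
  }
};

// Tensor-side interface, mirroring the buffer side.
struct TensorLikeType {
  virtual ~TensorLikeType() = default;
  virtual std::optional<bool>
  verifyCompatibleBufferType(const BufferLikeType &) const {
    return std::nullopt;
  }
};

// Verifier: ask the buffer side first, fall back to the tensor side, and
// fail when neither of the two types "knows" the other.
bool verifyToTensor(const BufferLikeType &buffer,
                    const TensorLikeType &tensor) {
  if (auto known = buffer.verifyCompatibleTensorType(tensor))
    return *known;
  if (auto known = tensor.verifyCompatibleBufferType(buffer))
    return *known;
  return false; // neither type recognizes the other
}

// A custom buffer type that recognizes only its own custom tensor type, so
// builtin tensor implementations never need to know about it.
struct MyTensor : TensorLikeType {};
struct MyBuffer : BufferLikeType {
  std::optional<bool>
  verifyCompatibleTensorType(const TensorLikeType &t) const override {
    return dynamic_cast<const MyTensor *>(&t) != nullptr;
  }
};
```

The fallback order is the design point: either side may veto or accept, and verification fails closed when both sides abstain.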

@andrey-golubev (Contributor Author), Jun 17, 2025:

I wouldn't go with double dispatch, to be honest, because there's essentially no difference between "tensor equivalent to buffer" vs "buffer equivalent to tensor" (we have both objects in hand, and they do not change between the two calls). For the time being, I guess we just put it somewhere (either into buffer-like or tensor-like)? Perhaps with more changes it would become clearer what to do here.

@matthias-springer (Member), Jun 17, 2025:

Sorry, double dispatch is the wrong name. It's more like querying both interfaces.

The reason why I'm suggesting this is to support custom conversions for builtin types. The type interface implementation of RankedTensorType won't know about your custom buffer type, so it cannot verify type compatibility. But the type interface implementation of your custom buffer type can do the verification.

@andrey-golubev (Contributor Author), Jun 17, 2025:

I see what you mean now. It's actually more (or less?) straightforward:
What we have is:

  1. custom.tensor -> custom.buffer <-- this is "problem-free" (because type interface implementation is on us)
  2. tensor<..., {custom encoding}> -> memref<..., {custom layout}> <-- this is meh

there's never really a situation where we'd bufferize builtin into non-builtin or non-builtin into builtin, but the case is interesting (I'd perhaps add support for this separately if that has any use).

edit: I guess for 2. we can keep the "default" logic - which is what's on main right now - and then let's see where it brings us.

Contributor Author:

Implemented half of the querying for TensorLikeType.

+    "$_op, $" # tensor # ", $" # buffer #")"
+  >
+>;

def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
BufferizableOpInterface,
-  SameOperandsAndResultShape,
-  SameOperandsAndResultElementType,
-  AllElementTypesMatch<["memref", "result"]>
+  Bufferization_TensorAndBufferMatch<"result", "buffer">
]> {
-  let summary = "create a tensor from a `memref`";
+  let summary = "create a tensor-like type from a buffer-like type";
let description = [{
-    An operation that creates a tensor from a `memref`. The result value is a
-    tensor whose shape and element type match the memref operand.
+    An operation that creates a tensor from a buffer. The result value is a
+    tensor-like type that must match the corresponding buffer-like operand as
+    per TensorLikeType::verifyCompatibleBufferType(). For builtins (TensorType
+    and BaseMemRefType), this means that shapes and element types match between
+    the tensor and the buffer.

The opposite of this op is `to_buffer`. Together, these two ops are
useful for source/target materializations when doing type conversions
-    involving tensors and memrefs.
+    involving tensors and buffers.

Example:

@@ -441,19 +453,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
}];

-  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+  let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
                        "the reference to load from",
-                       [MemReadAt<0, FullEffect>]>:$memref,
+                       [MemReadAt<0, FullEffect>]>:$buffer,
                        UnitAttr:$restrict, UnitAttr:$writable);
-  let results = (outs AnyTensor:$result);
+  let results = (outs Bufferization_TensorLikeTypeInterface:$result);

let extraClassDeclaration = [{
-    /// The result of a to_tensor is always a tensor.
-    TensorType getType() {
-      Type resultType = getResult().getType();
-      if (::llvm::isa<TensorType>(resultType))
-        return ::llvm::cast<TensorType>(resultType);
-      return {};
+    ::mlir::bufferization::TensorLikeType getType() {
+      return getResult().getType();
     }

//===------------------------------------------------------------------===//
@@ -472,22 +481,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
}
}];

let assemblyFormat = [{
-    $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-    `:` type($memref) `to` type($result)
+    $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
+    `:` type($buffer) `to` type($result)
}];

-  let builders = [
-    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
-      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
-      build($_builder, $_state, rtt, memref, restrict, writeable);
-    }]>
-  ];

let hasCanonicalizer = 1;
let hasFolder = 1;
}
@@ -502,10 +504,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
SameOperandsAndResultShape,
SameOperandsAndResultElementType,
Pure,
-  AllShapesMatch<["memref", "tensor"]>,
-  AllElementTypesMatch<["memref", "tensor"]>
+  Bufferization_TensorAndBufferMatch<"tensor", "buffer">
]> {
-  let summary = "cast a tensor to memref";
+  let summary = "cast a tensor-like type to a buffer-like type";
let description = [{
An operation that returns the future buffer of a `tensor`.

@@ -523,8 +524,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
the returned buffer) will not be written to.
}];

-  let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
-  let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+  let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
+  let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
@@ -559,7 +560,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
}];

let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
}];

let hasFolder = 1;
@@ -13,8 +13,15 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//

#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Types.h"

namespace mlir::bufferization {
struct BufferizationOptions;
class BufferizationState;
class BufferLikeType;
} // namespace mlir::bufferization

#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"

#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
@@ -21,10 +21,30 @@ def Bufferization_TensorLikeTypeInterface
let description = [{
Indicates that this type is a tensor type (similarly to a MLIR builtin
tensor) for bufferization purposes.

-    The interface currently has no methods as it is used by types to opt into
-    being supported by the bufferization procedures.
}];

+  let methods = [
+    InterfaceMethod<[{
+        Returns a BufferLike type for this TensorLike type.
+      }],
+      /*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
+      /*methodName=*/"getBufferType",
+      /*args=*/(ins
+        "const ::mlir::bufferization::BufferizationOptions &":$options,
+        "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
+      )
+    >,
+    InterfaceMethod<[{
+        Returns whether a BufferLike type is compatible with this TensorLike
+        type. The BufferLike type is assumed to be created by getBufferType().
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"verifyCompatibleBufferType",
+      /*args=*/(ins
+        "::mlir::bufferization::BufferLikeType":$bufferType,
+        "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
+    >
+  ];
}
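For illustration, a custom tensor type's implementation of these two methods might look like the following sketch; `MyTensorType`, `Options`, and the string-based diagnostic callback are hypothetical stand-ins for the MLIR interfaces, not real API:

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>

// Hypothetical stand-ins for BufferLikeType and BufferizationOptions.
struct BufferLikeType {
  std::string name;
};
struct Options {};

struct MyTensorType {
  std::string elementName;

  // Models TensorLikeTypeInterface::getBufferType: map this tensor to its
  // buffer counterpart, reporting problems through the diagnostic callback
  // and returning std::nullopt (standing in for failure) on error.
  std::optional<BufferLikeType>
  getBufferType(const Options &,
                const std::function<void(const std::string &)> &emitError) const {
    if (elementName.empty()) {
      emitError("cannot bufferize a tensor without an element type");
      return std::nullopt;
    }
    return BufferLikeType{"buffer<" + elementName + ">"};
  }

  // Models TensorLikeTypeInterface::verifyCompatibleBufferType: accept
  // exactly the buffer that getBufferType would produce for this tensor.
  bool verifyCompatibleBufferType(const BufferLikeType &buffer) const {
    return buffer.name == "buffer<" + elementName + ">";
  }
};
```

The point of the second method is the contract stated in the interface description: it only needs to verify buffers that `getBufferType()` could have produced.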

def Bufferization_BufferLikeTypeInterface
@@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// The operand was already bufferized. Take its type directly.
callerType = memrefType;
} else {
-      FailureOr<BaseMemRefType> maybeCallerType =
+      FailureOr<BufferLikeType> maybeCallerType =
bufferization::getBufferType(opOperand->get(), options, state,
invocationStack);
if (failed(maybeCallerType))
return failure();
-      callerType = *maybeCallerType;
+      assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
+      callerType = cast<BaseMemRefType>(*maybeCallerType);
}

if (!bufferType) {
@@ -164,8 +164,8 @@ struct SelectOpInterface
// buffers have different types, they differ only in their layout map. Cast
// both of them to the most dynamic MemRef type.
if (trueBuffer.getType() != falseBuffer.getType()) {
-    auto targetType =
-        bufferization::getBufferType(selectOp.getResult(), options, state);
+    auto targetType = bufferization::detail::asMemRefType(
+        bufferization::getBufferType(selectOp.getResult(), options, state));
if (failed(targetType))
return failure();
if (trueBuffer.getType() != *targetType)
@@ -187,10 +187,12 @@
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() && "invalid value");
-    auto trueType = bufferization::getBufferType(
-        selectOp.getTrueValue(), options, state, invocationStack);
-    auto falseType = bufferization::getBufferType(
-        selectOp.getFalseValue(), options, state, invocationStack);
+    auto trueType =
+        bufferization::detail::asMemRefType(bufferization::getBufferType(
+            selectOp.getTrueValue(), options, state, invocationStack));
+    auto falseType =
+        bufferization::detail::asMemRefType(bufferization::getBufferType(
+            selectOp.getFalseValue(), options, state, invocationStack));
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)