[mlir][bufferization] Support custom types (1/N) #142986
Conversation
@llvm/pr-subscribers-mlir-shape @llvm/pr-subscribers-mlir-bufferization Author: Andrei Golubev (andrey-golubev) Changes: Following the introduction of the TensorLike and BufferLike type interfaces (see 00eaff3), introduce the minimal changes required to bufferize a custom tensor operation into a custom buffer operation. To achieve this, a new conversion dialect interface is added that abstracts away the differences between the existing (tensor -> memref) conversion and custom conversions. The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first understand the basics and reach consensus on the design. Patch is 49.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142986.diff 19 Files Affected:
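For illustration, the end state this patch works toward is that the bufferization materialization ops can round-trip a custom tensor-like/buffer-like pair, not just builtin tensor/memref. A sketch in MLIR assembly (the `!mydialect.*` types are hypothetical placeholders for any types implementing TensorLike/BufferLike):

```
// to_buffer / to_tensor now accept any TensorLike / BufferLike pair.
%buf = bufferization.to_buffer %t read_only
    : !mydialect.tensor<4xf32> to !mydialect.buffer<4xf32>
%t2 = bufferization.to_tensor %buf restrict
    : !mydialect.buffer<4xf32> to !mydialect.tensor<4xf32>
```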
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..8390da956444d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
#include <optional>
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
namespace mlir {
class OpBuilder;
@@ -615,7 +616,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
const BufferizationState &state);
@@ -629,7 +630,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
@@ -738,6 +739,18 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
/// This is the default implementation of
/// BufferizableOpInterface::hasTensorSemantics
bool defaultHasTensorSemantics(Operation *op);
+
+/// This is a helper function used when buffer type is guaranteed to be memref.
+FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to verify the types in tensor and buffer
+/// worlds match.
+bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to perform the conversion.
+Type getTensorFromBuffer(Type buffer);
} // namespace detail
} // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
new file mode 100644
index 0000000000000..4164d1dcb9ea6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
@@ -0,0 +1,72 @@
+//===- BufferizationConversionInterface.h - Dialect Interface ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
+#include "mlir/IR/DialectInterface.h"
+
+namespace mlir {
+namespace bufferization {
+
+/// This class defines a virtual interface for conversions between tensor-like
+/// and buffer-like types.
+struct ConversionDialectInterface
+ : DialectInterface::Base<ConversionDialectInterface> {
+ using Base::Base;
+
+ /// Hook to customize tensor-like -> buffer-like conversion within a given
+ /// dialect. Returns a buffer-like type for the specific tensor-like type.
+ virtual FailureOr<BufferLikeType> getBufferType(
+ Value value, const BufferizationOptions &options,
+ const BufferizationState &state,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+ /// Hook to customize type checking between tensor-like and buffer-like types.
+ /// Given tensor `T` and buffer `B = getBufferType(T, ...)`, the call to
+ /// `typesMatch(T, B)` must return true.
+ virtual LogicalResult typesMatch(
+ TensorLikeType tensor, BufferLikeType buffer,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+ /// Hook to customize buffer-like -> tensor-like conversion, which is the
+ /// opposite of bufferization.
+ virtual TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const = 0;
+};
+
+/// Interface collection for conversion between tensor-like and buffer-like
+/// types, dispatches to a concrete interface implementation based on the
+/// dialect to which the given type belongs.
+struct ConversionInterface
+ : DialectInterfaceCollection<ConversionDialectInterface> {
+ using Base::Base;
+
+ /// Dispatches to ConversionDialectInterface::getBufferType() of the dialect
+ /// associated with the value type.
+ FailureOr<BufferLikeType> getBufferType(
+ Value value, const BufferizationOptions &options,
+ const BufferizationState &state,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+ /// Dispatches to ConversionDialectInterface::typesMatch() of the dialect
+ /// associated with the value type.
+ LogicalResult
+ typesMatch(TensorLikeType tensor, BufferLikeType buffer,
+ function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+ /// Dispatches to ConversionDialectInterface::getTensorFromBuffer() of the
+ /// dialect associated with the value type.
+ TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const;
+};
+
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
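A downstream dialect would subclass `ConversionDialectInterface` and register it on the dialect. A minimal sketch, not part of this patch — the dialect and the `MyTensorType`/`MyBufferType` types are hypothetical, and error handling is elided to the `emitError` callback:

```
// Sketch only: assumes custom MyTensorType / MyBufferType that implement
// TensorLikeType / BufferLikeType and carry a matching shape.
struct MyDialectConversionInterface
    : public bufferization::ConversionDialectInterface {
  using ConversionDialectInterface::ConversionDialectInterface;

  FailureOr<BufferLikeType> getBufferType(
      Value value, const BufferizationOptions &options,
      const BufferizationState &state,
      function_ref<InFlightDiagnostic(const Twine &)> emitError) const override {
    auto tensorType = dyn_cast<MyTensorType>(value.getType());
    if (!tensorType)
      return emitError("expected mydialect tensor type");
    // One-to-one mapping from the custom tensor to the custom buffer.
    return cast<BufferLikeType>(
        MyBufferType::get(tensorType.getContext(), tensorType.getShape()));
  }

  LogicalResult typesMatch(
      TensorLikeType tensor, BufferLikeType buffer,
      function_ref<InFlightDiagnostic(const Twine &)> emitError) const override {
    auto t = dyn_cast<MyTensorType>(tensor);
    auto b = dyn_cast<MyBufferType>(buffer);
    if (!t || !b || t.getShape() != b.getShape())
      return emitError("tensor and buffer types do not match");
    return success();
  }

  TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const override {
    auto b = cast<MyBufferType>(buffer);
    return cast<TensorLikeType>(MyTensorType::get(b.getContext(), b.getShape()));
  }
};
```

Registration would then happen in the dialect's `initialize()` via `addInterfaces<MyDialectConversionInterface>();`, mirroring how other dialect interfaces are attached.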
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 3d4dcdee2663b..277d56bc3f647 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -12,6 +12,7 @@
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -387,20 +388,28 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
// ToTensorOp
//===----------------------------------------------------------------------===//
+class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
+ "specified tensor and buffer types match",
+ CPred<
+ "::mlir::bufferization::detail::typesMatchAfterBufferization("
+ "$_op, $" # tensor # ", $" # buffer #")"
+ >
+>;
+
def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
BufferizableOpInterface,
SameOperandsAndResultShape,
SameOperandsAndResultElementType,
- AllElementTypesMatch<["memref", "result"]>
+ Bufferization_TensorAndBufferMatch<"result", "buffer">
]> {
- let summary = "create a tensor from a `memref`";
+ let summary = "create a tensor-like type from a buffer-like type";
let description = [{
- An operation that creates a tensor from a `memref`. The result value is a
- tensor whose shape and element type match the memref operand.
+ An operation that creates a tensor from a buffer. The result value is a
+ tensor-like type whose shape and element type match the buffer-like operand.
The opposite of this op is `to_buffer`. Together, these two ops are
useful for source/target materializations when doing type conversions
- involving tensors and memrefs.
+ involving tensors and buffers.
Example:
@@ -442,11 +451,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
}];
- let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+ let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
"the reference to load from",
- [MemReadAt<0, FullEffect>]>:$memref,
+ [MemReadAt<0, FullEffect>]>:$buffer,
UnitAttr:$restrict, UnitAttr:$writable);
- let results = (outs AnyTensor:$result);
+ let results = (outs Bufferization_TensorLikeTypeInterface:$result);
let extraClassDeclaration = [{
/// The result of a to_tensor is always a tensor.
@@ -473,19 +482,19 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state, SmallVector<Value> &invocationStack) {
- return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+ return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
}
}];
let assemblyFormat = [{
- $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
- `:` type($memref) `to` type($result)
+ $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
+ `:` type($buffer) `to` type($result)
}];
let builders = [
- OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
- auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
- build($_builder, $_state, rtt, memref, restrict, writeable);
+ OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+ auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType());
+ build($_builder, $_state, rtt, buffer, restrict, writeable);
}]>
];
@@ -503,10 +512,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
SameOperandsAndResultShape,
SameOperandsAndResultElementType,
Pure,
- AllShapesMatch<["memref", "tensor"]>,
- AllElementTypesMatch<["memref", "tensor"]>
+ Bufferization_TensorAndBufferMatch<"tensor", "buffer">
]> {
- let summary = "cast a tensor to memref";
+ let summary = "cast a tensor-like type to buffer-like type";
let description = [{
An operation that returns the future buffer of a `tensor`.
@@ -524,8 +532,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
the returned buffer) will not be written to.
}];
- let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
- let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+ let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
+ let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
@@ -560,7 +568,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
}];
let assemblyFormat = [{
- $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
+ $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
}];
let hasFolder = 1;
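With the operand/result renames above (`$memref` to `$buffer`), the builtin case keeps its familiar textual form, since memref types already satisfy the BufferLike constraint. A sketch of the assembly after this change:

```
// Builtin tensor/memref still round-trip exactly as before.
%m  = bufferization.to_buffer %t read_only : tensor<4xf32> to memref<4xf32>
%t2 = bufferization.to_tensor %m restrict writable : memref<4xf32> to tensor<4xf32>
```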
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index a441b8b66659e..f56c10555f02c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// The operand was already bufferized. Take its type directly.
callerType = memrefType;
} else {
- FailureOr<BaseMemRefType> maybeCallerType =
+ FailureOr<BufferLikeType> maybeCallerType =
bufferization::getBufferType(opOperand->get(), options, state,
invocationStack);
if (failed(maybeCallerType))
return failure();
- callerType = *maybeCallerType;
+ assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
+ callerType = cast<BaseMemRefType>(*maybeCallerType);
}
if (!bufferType) {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index a57d58ab28d28..021a557f68b4b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -164,8 +164,8 @@ struct SelectOpInterface
// buffers have different types, they differ only in their layout map. Cast
// both of them to the most dynamic MemRef type.
if (trueBuffer.getType() != falseBuffer.getType()) {
- auto targetType =
- bufferization::getBufferType(selectOp.getResult(), options, state);
+ auto targetType = bufferization::detail::castToMemRef(
+ bufferization::getBufferType(selectOp.getResult(), options, state));
if (failed(targetType))
return failure();
if (trueBuffer.getType() != *targetType)
@@ -187,10 +187,12 @@ struct SelectOpInterface
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() && "invalid value");
- auto trueType = bufferization::getBufferType(
- selectOp.getTrueValue(), options, state, invocationStack);
- auto falseType = bufferization::getBufferType(
- selectOp.getFalseValue(), options, state, invocationStack);
+ auto trueType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ selectOp.getTrueValue(), options, state, invocationStack));
+ auto falseType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ selectOp.getFalseValue(), options, state, invocationStack));
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..d00605a7b9a17 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -211,8 +212,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
- FailureOr<BaseMemRefType> copyBufferType =
- getBufferType(tensor, options, state);
+ auto copyBufferType =
+ detail::castToMemRef(getBufferType(tensor, options, state));
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -673,28 +674,28 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options,
const BufferizationState &state) {
#ifndef NDEBUG
- auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
+ auto tensorType = llvm::dyn_cast<TensorLikeType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG
// Replace "%t = to_tensor %m" with %m.
if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
- return toTensorOp.getMemref();
+ return toTensorOp.getBuffer();
// Insert to_buffer op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
- FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
- if (failed(memrefType))
+ FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
+ if (failed(bufferType))
return failure();
- ensureToBufferOpIsValid(value, *memrefType);
+ ensureToBufferOpIsValid(value, *bufferType);
return rewriter
- .create<bufferization::ToBufferOp>(value.getLoc(), *memrefType, value)
+ .create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value)
.getResult();
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state) {
SmallVector<Value> invocationStack;
@@ -702,11 +703,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) {
- assert(llvm::isa<TensorType>(value.getType()) &&
+ assert(llvm::isa<TensorLikeType>(value.getType()) &&
"unexpected non-tensor type");
invocationStack.push_back(value);
auto popFromStack =
@@ -718,13 +719,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
if (bufferizableOp)
return bufferizableOp.getBufferType(value, options, state, invocationStack);
- // Op is not bufferizable.
- auto memSpace =
- options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
- if (!memSpace.has_value())
- return op->emitError("could not infer memory space");
-
- return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+ // Op is not bufferizable, use conversion interface.
+ bufferization::ConversionInterface iface(value.getContext());
+ return iface.getBufferType(value, options, state, [&](const Twine &message) {
+ return op->emitError(message);
+ });
}
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -744,12 +743,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
SmallVector<Value> replacements;
for (OpResult opResult : op->getOpResults()) {
Value replacement = values[opResult.getResultNumber()];
- if (llvm::isa<TensorType>(opResult.getType())) {
+ if (llvm::isa<TensorLikeType>(opResult.getType())) {
// The OpResult is a tensor. Such values are replaced with memrefs during
// bufferization.
- assert((llvm::isa<MemRefType>(replacement.getType()) ||
- llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
- "tensor op result should be replaced with a memref value");
+ assert(llvm::isa<BufferLikeType>(replacement.getType()) &&
+ "tensor op result should be replaced with a buffer value");
// The existing uses of the OpResult still expect a tensor. Insert a
// ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
// loose all of its users and eventually DCE away.
@@ -970,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
- return getBufferType(equivalentOperand, options, bufferizationState,
- invocationStack);
+ return castToMemRef(getBufferType(equivalentOperand, options,
+ bufferizationState, invocationStack));
}
// If we do not know the memory space and there is no default memory space,
@@ -1031,7 +1029,7 @@ bufferization::detail::unknownGe...
[truncated]
@llvm/pr-subscribers-mlir-sparse Author: Andrei Golubev (andrey-golubev) (same comment as above)
+ auto trueType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ selectOp.getTrueValue(), options, state, invocationStack));
+ auto falseType =
+ bufferization::detail::castToMemRef(bufferization::getBufferType(
+ selectOp.getFalseValue(), options, state, invocationStack));
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..d00605a7b9a17 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -211,8 +212,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
- FailureOr<BaseMemRefType> copyBufferType =
- getBufferType(tensor, options, state);
+ auto copyBufferType =
+ detail::castToMemRef(getBufferType(tensor, options, state));
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -673,28 +674,28 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options,
const BufferizationState &state) {
#ifndef NDEBUG
- auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
+ auto tensorType = llvm::dyn_cast<TensorLikeType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
#endif // NDEBUG
// Replace "%t = to_tensor %m" with %m.
if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
- return toTensorOp.getMemref();
+ return toTensorOp.getBuffer();
// Insert to_buffer op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
- FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
- if (failed(memrefType))
+ FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
+ if (failed(bufferType))
return failure();
- ensureToBufferOpIsValid(value, *memrefType);
+ ensureToBufferOpIsValid(value, *bufferType);
return rewriter
- .create<bufferization::ToBufferOp>(value.getLoc(), *memrefType, value)
+ .create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value)
.getResult();
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state) {
SmallVector<Value> invocationStack;
@@ -702,11 +703,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) {
- assert(llvm::isa<TensorType>(value.getType()) &&
+ assert(llvm::isa<TensorLikeType>(value.getType()) &&
"unexpected non-tensor type");
invocationStack.push_back(value);
auto popFromStack =
@@ -718,13 +719,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
if (bufferizableOp)
return bufferizableOp.getBufferType(value, options, state, invocationStack);
- // Op is not bufferizable.
- auto memSpace =
- options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
- if (!memSpace.has_value())
- return op->emitError("could not infer memory space");
-
- return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+ // Op is not bufferizable, use conversion interface.
+ bufferization::ConversionInterface iface(value.getContext());
+ return iface.getBufferType(value, options, state, [&](const Twine &message) {
+ return op->emitError(message);
+ });
}
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -744,12 +743,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
SmallVector<Value> replacements;
for (OpResult opResult : op->getOpResults()) {
Value replacement = values[opResult.getResultNumber()];
- if (llvm::isa<TensorType>(opResult.getType())) {
+ if (llvm::isa<TensorLikeType>(opResult.getType())) {
// The OpResult is a tensor. Such values are replaced with memrefs during
// bufferization.
- assert((llvm::isa<MemRefType>(replacement.getType()) ||
- llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
- "tensor op result should be replaced with a memref value");
+ assert(llvm::isa<BufferLikeType>(replacement.getType()) &&
+ "tensor op result should be replaced with a buffer value");
// The existing uses of the OpResult still expect a tensor. Insert a
// ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
      // lose all of its users and eventually DCE away.
@@ -970,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
- return getBufferType(equivalentOperand, options, bufferizationState,
- invocationStack);
+ return castToMemRef(getBufferType(equivalentOperand, options,
+ bufferizationState, invocationStack));
}
// If we do not know the memory space and there is no default memory space,
@@ -1031,7 +1029,7 @@ bufferization::detail::unknownGe...
[truncated]
|
@christopherbate I think you might find this interesting also
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
Many thanks for the upstream contribution! I've seen some changes like this that'd break downstream projects. It is very reasonable because you're improving the bufferization.
I'd appreciate it if you can list down some naming changes in the PR description before you land the PR. It'd save us some time to look at all the details. E.g., the operand name of ToTensorOp becomes buffer; the return type of bufferization::getBufferType becomes FailureOr<BufferLikeType>, etc.
(It is definitely not a requirement, and it is okay if you don't do it. I appreciate your contribution. Thanks again! 🙏)
Not sure this is good PR description-wise, but I mentioned the two changes (actually, what you mention is probably the only two here?).
@matthias-springer (gently pinging)
auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
build($_builder, $_state, rtt, memref, restrict, writeable);
OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType());
@matthias-springer I took a look at the current implementation. I think there's a problem with this API if we switch to a BufferLike's function:
- ToTensor seems to infer the tensor type from buffer type
- Without options here, there's no (good) way to customize the (reverse) bufferization behavior for builtins
Should this be changed somehow? The things I could think of:
- Drop type inference and make ToTensor always be constructed with an explicit type (so, user has to care about reverse bufferization)
- Pass bufferization options to the builder
- Assume that the current API is the way it is -> meaning that options cannot be used as customization point for builtins also and a different mechanism is needed
- something else?
If we go with the above-mentioned verifyCompatibleTensorType approach, type inference must be dropped, indeed. I think that's alright. We can add some helper functions to make it easy for folks to migrate their code.
I wouldn't pass the BufferizationOptions to the op builder. I think it's better to explicitly specify the result type when constructing a to_tensor / to_buffer op.
If we go with the above-mentioned verifyCompatibleTensorType approach, type inference must be dropped, indeed. I think that's alright. We can add some helper functions to make it easy for folks to migrate their code.
Turns out this builder was such a helper function. In fact, many places seem to just be able to provide the valid tensor type without additional reverse conversions.
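To make the trade-off concrete, here is a minimal standalone sketch of the two construction styles discussed above. All names (Buffer, Tensor, buildToTensor, getTensorFromBuffer) are toy stand-ins for illustration, not the actual MLIR API:

```cpp
#include <cassert>
#include <string>

// Toy types standing in for buffer-like / tensor-like types.
struct Buffer { std::string elem; int rank; };
struct Tensor { std::string elem; int rank; };

// Hypothetical stand-in for the removed inferring builder: derives the
// "reverse bufferized" tensor type from a buffer type. Custom
// tensor -> buffer conversions cannot be honored here, which is why the
// op builder now takes the result type explicitly instead.
Tensor getTensorFromBuffer(const Buffer &b) { return {b.elem, b.rank}; }

// Explicit-result-type construction: the caller decides the tensor type.
struct ToTensorOp { Tensor result; Buffer operand; };
ToTensorOp buildToTensor(Tensor resultType, Buffer buffer) {
  return {resultType, buffer};
}
```

With the explicit builder, callers that still want inference can compose the two helpers; callers with custom types supply their own result type.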
class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
"specified tensor and buffer types match",
CPred<
"::mlir::bufferization::detail::typesMatchAfterBufferization("
@matthias-springer this would be the other problematic place. With the current approach, we want to validate that bufferization is "valid" on a tensor <-> buffer level.
The current logic checks tensor.getShape() == buffer.getShape() and tensor.getElementType() == buffer.getElementType(); this practically means that TensorLike and BufferLike are ShapedType (fine by me), but even that is not enough. We've recently started to experiment with shape bounds and dynamic shapes -> getShape checking might not be sufficient.
Instead, I think we should either restore the old comparison logic (which was changed in ced2fc7) or - more likely - have this put into an interface so that it's a customization point.
But then, which interface? TensorLike? BufferLike? Since it's a type matching function, it's kind of valid to be in both.
How about some kind of double dispatch? E.g., for to_tensor:
- Op verifier checks that the operand type implements BufferLikeTypeInterface.
- Op verifier checks that the result type implements TensorLikeTypeInterface.
- Op verifier calls bufferLikeType.verifyCompatibleTensorType(tensorLikeType). The result is std::optional<LogicalResult>: success means that the type is compatible, failure means that it is incompatible, nullopt means that we don't know.
- In case of nullopt, call tensorLikeType.verifyCompatibleBufferType(bufferLikeType).
- If the result is still nullopt, fail verification, because neither of the two types "knows" the other.
I wouldn't go with double dispatch to be honest because there's majorly no difference between "tensor equivalent to buffer" vs "buffer equivalent to tensor" (we have both things which do not change between the two calls). For the time being, I guess we just put it somewhere? (either to buffer-like or to tensor-like). Perhaps with more changes it would be clearer what to do here.
Sry, double dispatch is the wrong name. It's more like querying both interfaces.
The reason why I'm suggesting this is to support custom conversions for builtin types. The type interface implementation of RankedTensorType won't know about your custom buffer type, so it cannot verify type compatibility. But the type interface implementation of your custom buffer type can do the verification.
I see what you mean now. It's actually more (or less?) straightforward:
What we have is:
1. custom.tensor -> custom.buffer <-- this is "problem-free" (because type interface implementation is on us)
2. tensor<..., {custom encoding}> -> memref<..., {custom layout}> <-- this is meh
edit: I guess for 2. we can keep the "default" logic - which is what's on main
right now - and then let's see where it brings us.
Implemented the TensorLikeType half of the querying.
With the last update, is it worth splitting these two into separate patches? (Maybe also separate "memref" -> "buffer" renames in ToTensor/ToBuffer ops, etc.) Alternatively, I guess I can just update the PR description to mention these changes. (Edit: updated the PR description for now)
LGTM for the sparse changes (which are rather mechanical)
Looks good overall!
@@ -267,7 +268,7 @@ struct BufferizationOptions {
/// Tensor -> MemRef type converter.
/// Parameters: Value, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
Value, Attribute memorySpace, const BufferizationOptions &)>;
TensorType, Attribute memorySpace, const BufferizationOptions &)>;
Why did this change? Can we keep Value here to keep the number of API changes small?
I need this together with the getMemRefType() API change since TensorLike::getBufferType() works with types (there are no values). See https://github.com/llvm/llvm-project/pull/142986/files#diff-93e08b1e03259ffa2ff33aec1bd8f907f066e1d86cf83cf140fb031fb31c6f3aR71-R72:
llvm::FailureOr<BufferLikeType> getBufferType(mlir::Type tensor, // mlir::Type here now
const BufferizationOptions &options,
const BufferizationState &state,
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
auto tensorType = cast<TensorType>(type);
// ...
return cast<BufferLikeType>(
getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
}
I could separate this into another patch if you prefer
If this can be a separate PR, let's do that. Integrating bufferization-related changes is often quite difficult for downstream users, so I try to keep them as small as possible, so they can integrate one at a time.
separated - #144658
Value tensor,
Value buffer) {
assert(isa<TensorLikeType>(tensor.getType()) && "expected TensorLikeType");
assert(isa<BufferLikeType>(buffer.getType()) && "expected BufferLikeType");
nit: dyn_cast, then you don't need to cast a second time below
do you mean something like: auto x = dyn_cast<Blah>(y) then assert(x != nullptr)?
I keep them separated like this since release builds would likely strip asserts and so one gets just a cheap llvm::cast<> (no runtime checks afair) instead of dyn_cast/isa.
I think llvm::cast still does the check and asserts. Actually, this means you can just remove the assertions entirely.
done.
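For readers less familiar with the casting idiom discussed above, here is a minimal toy mimic (not the real llvm:: implementations, which use a classof-based mechanism rather than dynamic_cast): dyn_cast returns null on mismatch, while cast only checks via an assert that release builds strip.

```cpp
#include <cassert>

// Toy type hierarchy standing in for MLIR types.
struct Type { virtual ~Type() = default; };
struct MemRefType : Type {};
struct TensorType : Type {};

template <typename T> bool isa(const Type *t) {
  return dynamic_cast<const T *>(t) != nullptr;
}
template <typename T> const T *dyn_cast(const Type *t) {
  return dynamic_cast<const T *>(t);  // runtime-checked, may be null
}
template <typename T> const T *cast(const Type *t) {
  assert(isa<T>(t) && "cast on mismatched type");  // debug-only check
  return static_cast<const T *>(t);                 // unchecked in release
}

// A sample value to query in usage examples.
inline const Type *sampleMemRef() {
  static MemRefType m;
  return &m;
}
```

This is why the explicit isa-assertions in the PR were redundant: cast already asserts in debug builds.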
const BufferizationState &state,
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
auto tensorType = cast<TensorType>(tensor);
// Fall back to tensor -> memref conversion.
I'm a bit confused about the comment. Why is this a "fall back"?
never mind, outdated comment that i missed when moving stuff around.
edit: removed the comments.
/*methodName=*/"getBufferType",
/*args=*/(ins
"const ::mlir::bufferization::BufferizationOptions &":$options,
"const ::mlir::bufferization::BufferizationState &":$state,
Ideally, I'd like to avoid passing the BufferizationState here. Do we really need it?
honestly, i have no idea what this state thing is... I don't really need it myself, imho options is enough. I saw that getBufferType accepts it now, so there's a chance it's actually used by something? (no clue). let me try to make some sense out of it.
The state can be used during bufferization to store symbol tables, etc. We don't need it during type conversions.
awesome! removed the state from this API.
@@ -738,6 +740,14 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
/// This is the default implementation of
/// BufferizableOpInterface::hasTensorSemantics
bool defaultHasTensorSemantics(Operation *op);

/// This is a helper function used when buffer type is guaranteed to be memref.
FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
nit: Can you rename this to asMemRefType? The term "cast" is overloaded, it could refer to memref.cast. Let's also document the exact behavior of the function: failure -> failure.
done.
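The "failure -> failure" contract requested above can be sketched standalone. This is a toy model only: FailureOr<T> is approximated with std::optional<T> (nullopt playing the role of failure), and BufferLike/MemRef are hypothetical stand-in types:

```cpp
#include <cassert>
#include <optional>
#include <string>

// Toy stand-ins for BufferLikeType / BaseMemRefType.
struct BufferLike { std::string kind; };
struct MemRef { std::string kind; };

// Hypothetical asMemRefType: a failed input stays a failure; otherwise the
// buffer-like value is assumed (by the caller's contract) to actually be a
// memref and is converted.
std::optional<MemRef> asMemRefType(std::optional<BufferLike> bufferType) {
  if (!bufferType)
    return std::nullopt;  // failure -> failure
  assert(bufferType->kind == "memref" && "buffer guaranteed to be a memref");
  return MemRef{bufferType->kind};
}
```

Keeping the failure path explicit lets call sites chain it directly around getBufferType-style calls without extra error handling.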
An operation that creates a tensor from a `memref`. The result value is a
tensor whose shape and element type match the memref operand.
An operation that creates a tensor from a buffer. The result value is a
tensor-like type whose shape and element type match the buffer-like operand.
Mention in the documentation that operand and result types must be compatible as per TensorLikeTypeInterface::verifyCompatibleBufferType.
good point. rewrote this piece.
Following the introduction of TensorLike and BufferLike type interfaces (see 00eaff3), introduce minimal changes required to bufferize a custom tensor operation into a custom buffer operation. To achieve this, a new conversion dialect interface is added that abstracts away the differences between existing (tensor -> memref) and custom conversions. The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first understand the basics and reach consensus design-wise.
The builder is ambiguous given customizable tensor-like -> buffer-like conversion and is thus removed. The places where reverse bufferization has to happen rely on the pre-existing functionality.
Noteworthy changes:
* bufferization::getMemRefType() accepts a TensorType instead of Value to achieve broader applicability
* BufferizationOptions::UnknownTypeConverterFn accepts a TensorType instead of Value to allow it being used in the updated getMemRefType()
@matthias-springer do you mind merging this also (again, I have no rights :|) or should we wait for some more reviews/feedback? I only see Aart's LGTM for sparsity-related changes.
… didnt cover everything)
#144743) … cover everything)
This reverts commit ee070d0.
…lvm#144721 didnt… (llvm#144743)" This reverts commit 2a41350.
Following the addition of TensorLike and BufferLike type interfaces (see 00eaff3), introduce minimal changes required to bufferize a custom tensor operation into a custom buffer operation.
To achieve this, new interface methods are added to TensorLike type interface that abstract away the differences between existing (tensor -> memref) and custom conversions.
The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first understand the basics and reach consensus design-wise.
Notable changes:
* the operand of bufferization.to_tensor is renamed from memref to buffer; likewise, the result of bufferization.to_buffer is renamed from memref to buffer
* bufferization::getBufferType() now returns FailureOr<BufferLikeType> instead of FailureOr<BaseMemRefType>
* the type-inferring ToTensorOp builder is removed; the result tensor type must now be specified explicitly