
[mlir][bufferization] Support custom types (1/N) #142986


Merged
merged 5 commits on Jun 18, 2025

Conversation

andrey-golubev (Contributor) commented Jun 5, 2025

Following the addition of the TensorLike and BufferLike type interfaces (see 00eaff3), this patch introduces the minimal changes required to bufferize a custom tensor operation into a custom buffer operation.

To achieve this, new interface methods are added to the TensorLike type interface that abstract away the differences between the existing (tensor -> memref) conversion and custom conversions.

The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first establish the basics and reach consensus on the design.


Notable changes:

  • mlir::bufferization::getBufferType() returns BufferLikeType (instead of BaseMemRefType)
  • ToTensorOp / ToBufferOp now operate on TensorLikeType / BufferLikeType; the operation argument "memref" is renamed to "buffer"
  • ToTensorOp's tensor-type-inferring builder is dropped (users now need to provide the tensor type explicitly)
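As a sketch of the resulting IR, using the builtin tensor/memref types (which implement TensorLike/BufferLike; a custom type pair would be spelled the same way), the result type of `to_tensor` is now written out explicitly since the inferring builder is gone:

```mlir
// Round-trip through the updated ops. The explicit "to <result-type>"
// follows the assembly format introduced in this patch.
func.func @round_trip(%t: tensor<4xf32>) -> tensor<4xf32> {
  %b = bufferization.to_buffer %t : tensor<4xf32> to memref<4xf32>
  %r = bufferization.to_tensor %b : memref<4xf32> to tensor<4xf32>
  return %r : tensor<4xf32>
}
```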

llvmbot (Member) commented Jun 5, 2025

@llvm/pr-subscribers-mlir-shape
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir-bufferization

Author: Andrei Golubev (andrey-golubev)

Changes

Following the introduction of the TensorLike and BufferLike type interfaces (see 00eaff3), this patch introduces the minimal changes required to bufferize a custom tensor operation into a custom buffer operation.

To achieve this, a new conversion dialect interface is added that abstracts away the differences between the existing (tensor -> memref) conversion and custom conversions.

The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first establish the basics and reach consensus on the design.


Patch is 49.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142986.diff

19 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+15-2)
  • (added) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h (+72)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+28-20)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+3-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+8-6)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+51-25)
  • (added) mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp (+67)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+11-10)
  • (modified) mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+4-4)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+28-23)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+8-6)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+20-1)
  • (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (+49)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+23)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.h (+1)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+56-2)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..8390da956444d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
 #include <optional>
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 
 namespace mlir {
 class OpBuilder;
@@ -615,7 +616,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         const BufferizationState &state);
 
@@ -629,7 +630,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         const BufferizationState &state,
                                         SmallVector<Value> &invocationStack);
@@ -738,6 +739,18 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
 /// This is the default implementation of
 /// BufferizableOpInterface::hasTensorSemantics
 bool defaultHasTensorSemantics(Operation *op);
+
+/// This is a helper function used when buffer type is guaranteed to be memref.
+FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to verify the types in tensor and buffer
+/// worlds match.
+bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to perform the conversion.
+Type getTensorFromBuffer(Type buffer);
 } // namespace detail
 
 } // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
new file mode 100644
index 0000000000000..4164d1dcb9ea6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
@@ -0,0 +1,72 @@
+//===- BufferizationConversionInterface.h - Dialect Interface ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
+#include "mlir/IR/DialectInterface.h"
+
+namespace mlir {
+namespace bufferization {
+
+/// This class defines a virtual interface for conversions between tensor-like
+/// and buffer-like types.
+struct ConversionDialectInterface
+    : DialectInterface::Base<ConversionDialectInterface> {
+  using Base::Base;
+
+  /// Hook to customize tensor-like -> buffer-like conversion within a given
+  /// dialect. Returns a buffer-like type for the specific tensor-like type.
+  virtual FailureOr<BufferLikeType> getBufferType(
+      Value value, const BufferizationOptions &options,
+      const BufferizationState &state,
+      function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+  /// Hook to customize type checking between tensor-like and buffer-like types.
+  /// Given tensor `T` and buffer `B = getBufferType(T, ...)`, the call to
+  /// `typesMatch(T, B)` must return true.
+  virtual LogicalResult typesMatch(
+      TensorLikeType tensor, BufferLikeType buffer,
+      function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+  /// Hook to customize buffer-like -> tensor-like conversion, which is the
+  /// opposite of bufferization.
+  virtual TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const = 0;
+};
+
+/// Interface collection for conversion between tensor-like and buffer-like
+/// types, dispatches to a concrete interface implementation based on the
+/// dialect to which the given type belongs.
+struct ConversionInterface
+    : DialectInterfaceCollection<ConversionDialectInterface> {
+  using Base::Base;
+
+  /// Dispatches to ConversionDialectInterface::getBufferType() of the dialect
+  /// associated with the value type.
+  FailureOr<BufferLikeType> getBufferType(
+      Value value, const BufferizationOptions &options,
+      const BufferizationState &state,
+      function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+  /// Dispatches to ConversionDialectInterface::typesMatch() of the dialect
+  /// associated with the value type.
+  LogicalResult
+  typesMatch(TensorLikeType tensor, BufferLikeType buffer,
+             function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+  /// Dispatches to ConversionDialectInterface::getTensorFromBuffer() of the
+  /// dialect associated with the value type.
+  TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const;
+};
+
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 3d4dcdee2663b..277d56bc3f647 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -387,20 +388,28 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
 // ToTensorOp
 //===----------------------------------------------------------------------===//
 
+class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
+  "specified tensor and buffer types match",
+  CPred<
+    "::mlir::bufferization::detail::typesMatchAfterBufferization("
+        "$_op, $" # tensor # ", $" # buffer #")"
+  >
+>;
+
 def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
-    AllElementTypesMatch<["memref", "result"]>
+    Bufferization_TensorAndBufferMatch<"result", "buffer">
   ]> {
-  let summary = "create a tensor from a `memref`";
+  let summary = "create a tensor-like type from a buffer-like type";
   let description = [{
-    An operation that creates a tensor from a `memref`. The result value is a
-    tensor whose shape and element type match the memref operand.
+    An operation that creates a tensor from a buffer. The result value is a
+    tensor-like type whose shape and element type match the buffer-like operand.
 
     The opposite of this op is `to_buffer`. Together, these two ops are
     useful for source/target materializations when doing type conversions
-    involving tensors and memrefs.
+    involving tensors and buffers.
 
     Example:
 
@@ -442,11 +451,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     away. However, such IR is no longer bufferizable with One-Shot Bufferize.
   }];
 
-  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+  let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
                            "the reference to load from",
-                           [MemReadAt<0, FullEffect>]>:$memref,
+                           [MemReadAt<0, FullEffect>]>:$buffer,
                        UnitAttr:$restrict, UnitAttr:$writable);
-  let results = (outs AnyTensor:$result);
+  let results = (outs Bufferization_TensorLikeTypeInterface:$result);
 
   let extraClassDeclaration = [{
     /// The result of a to_tensor is always a tensor.
@@ -473,19 +482,19 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
     }
   }];
 
   let assemblyFormat = [{
-    $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-      `:` type($memref) `to` type($result)
+    $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
+      `:` type($buffer) `to` type($result)
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
-      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
-      build($_builder, $_state, rtt, memref, restrict, writeable);
+    OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+      auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType());
+      build($_builder, $_state, rtt, buffer, restrict, writeable);
     }]>
   ];
 
@@ -503,10 +512,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
     Pure,
-    AllShapesMatch<["memref", "tensor"]>,
-    AllElementTypesMatch<["memref", "tensor"]>
+    Bufferization_TensorAndBufferMatch<"tensor", "buffer">
   ]> {
-  let summary = "cast a tensor to memref";
+  let summary = "cast a tensor-like type to buffer-like type";
   let description = [{
     An operation that returns the future buffer of a `tensor`.
 
@@ -524,8 +532,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     the returned buffer) will not be written to.
   }];
 
-  let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
-  let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+  let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
+  let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);
 
   let extraClassDeclaration = [{
     //===------------------------------------------------------------------===//
@@ -560,7 +568,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
   }];
 
   let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
   }];
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index a441b8b66659e..f56c10555f02c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         // The operand was already bufferized. Take its type directly.
         callerType = memrefType;
       } else {
-        FailureOr<BaseMemRefType> maybeCallerType =
+        FailureOr<BufferLikeType> maybeCallerType =
             bufferization::getBufferType(opOperand->get(), options, state,
                                          invocationStack);
         if (failed(maybeCallerType))
           return failure();
-        callerType = *maybeCallerType;
+        assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
+        callerType = cast<BaseMemRefType>(*maybeCallerType);
       }
 
       if (!bufferType) {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index a57d58ab28d28..021a557f68b4b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -164,8 +164,8 @@ struct SelectOpInterface
     // buffers have different types, they differ only in their layout map. Cast
     // both of them to the most dynamic MemRef type.
     if (trueBuffer.getType() != falseBuffer.getType()) {
-      auto targetType =
-          bufferization::getBufferType(selectOp.getResult(), options, state);
+      auto targetType = bufferization::detail::castToMemRef(
+          bufferization::getBufferType(selectOp.getResult(), options, state));
       if (failed(targetType))
         return failure();
       if (trueBuffer.getType() != *targetType)
@@ -187,10 +187,12 @@ struct SelectOpInterface
                 SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
     assert(value == selectOp.getResult() && "invalid value");
-    auto trueType = bufferization::getBufferType(
-        selectOp.getTrueValue(), options, state, invocationStack);
-    auto falseType = bufferization::getBufferType(
-        selectOp.getFalseValue(), options, state, invocationStack);
+    auto trueType =
+        bufferization::detail::castToMemRef(bufferization::getBufferType(
+            selectOp.getTrueValue(), options, state, invocationStack));
+    auto falseType =
+        bufferization::detail::castToMemRef(bufferization::getBufferType(
+            selectOp.getFalseValue(), options, state, invocationStack));
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..d00605a7b9a17 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -211,8 +212,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
   if (copy)
     return allocTensorOp.getResult();
-  FailureOr<BaseMemRefType> copyBufferType =
-      getBufferType(tensor, options, state);
+  auto copyBufferType =
+      detail::castToMemRef(getBufferType(tensor, options, state));
   if (failed(copyBufferType))
     return failure();
   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -673,28 +674,28 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                           const BufferizationOptions &options,
                                           const BufferizationState &state) {
 #ifndef NDEBUG
-  auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
+  auto tensorType = llvm::dyn_cast<TensorLikeType>(value.getType());
   assert(tensorType && "unexpected non-tensor type");
 #endif // NDEBUG
 
   // Replace "%t = to_tensor %m" with %m.
   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
-    return toTensorOp.getMemref();
+    return toTensorOp.getBuffer();
 
   // Insert to_buffer op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
-  if (failed(memrefType))
+  FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
+  if (failed(bufferType))
     return failure();
-  ensureToBufferOpIsValid(value, *memrefType);
+  ensureToBufferOpIsValid(value, *bufferType);
   return rewriter
-      .create<bufferization::ToBufferOp>(value.getLoc(), *memrefType, value)
+      .create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value)
       .getResult();
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state) {
   SmallVector<Value> invocationStack;
@@ -702,11 +703,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
-  assert(llvm::isa<TensorType>(value.getType()) &&
+  assert(llvm::isa<TensorLikeType>(value.getType()) &&
          "unexpected non-tensor type");
   invocationStack.push_back(value);
   auto popFromStack =
@@ -718,13 +719,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
   if (bufferizableOp)
     return bufferizableOp.getBufferType(value, options, state, invocationStack);
 
-  // Op is not bufferizable.
-  auto memSpace =
-      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
-  if (!memSpace.has_value())
-    return op->emitError("could not infer memory space");
-
-  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+  // Op is not bufferizable, use conversion interface.
+  bufferization::ConversionInterface iface(value.getContext());
+  return iface.getBufferType(value, options, state, [&](const Twine &message) {
+    return op->emitError(message);
+  });
 }
 
 bool bufferization::hasTensorSemantics(Operation *op) {
@@ -744,12 +743,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
   SmallVector<Value> replacements;
   for (OpResult opResult : op->getOpResults()) {
     Value replacement = values[opResult.getResultNumber()];
-    if (llvm::isa<TensorType>(opResult.getType())) {
+    if (llvm::isa<TensorLikeType>(opResult.getType())) {
       // The OpResult is a tensor. Such values are replaced with memrefs during
       // bufferization.
-      assert((llvm::isa<MemRefType>(replacement.getType()) ||
-              llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
-             "tensor op result should be replaced with a memref value");
+      assert(llvm::isa<BufferLikeType>(replacement.getType()) &&
+             "tensor op result should be replaced with a buffer value");
       // The existing uses of the OpResult still expect a tensor. Insert a
       // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
       // loose all of its users and eventually DCE away.
@@ -970,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return getBufferType(equivalentOperand, options, bufferizationState,
-                         invocationStack);
+    return castToMemRef(getBufferType(equivalentOperand, options,
+                                      bufferizationState, invocationStack));
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -1031,7 +1029,7 @@ bufferization::detail::unknownGe...
[truncated]
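For downstream dialects, plugging into this mechanism would look roughly as follows. This is a non-compilable sketch based on the `ConversionDialectInterface` declaration shown in the diff above; `MyDialect`, `MyTensorType`, and `MyBufferType` are hypothetical names, not part of this patch:

```
// Sketch only: assumes MyDialect defines MyTensorType / MyBufferType that
// implement TensorLikeType / BufferLikeType respectively.
struct MyDialectConversionInterface
    : public bufferization::ConversionDialectInterface {
  using ConversionDialectInterface::ConversionDialectInterface;

  FailureOr<bufferization::BufferLikeType> getBufferType(
      Value value, const bufferization::BufferizationOptions &options,
      const bufferization::BufferizationState &state,
      function_ref<InFlightDiagnostic(const Twine &)> emitError) const override {
    // Map the custom tensor to its buffer counterpart.
    auto tensor = cast<MyTensorType>(value.getType());
    return cast<bufferization::BufferLikeType>(
        MyBufferType::get(tensor.getContext(), /*...*/));
  }

  LogicalResult typesMatch(
      bufferization::TensorLikeType tensor,
      bufferization::BufferLikeType buffer,
      function_ref<InFlightDiagnostic(const Twine &)> emitError) const override {
    // Must accept any pair produced by getBufferType() above.
    if (isa<MyTensorType>(tensor) && isa<MyBufferType>(buffer))
      return success();
    return emitError("mismatched tensor/buffer pair");
  }

  bufferization::TensorLikeType
  getTensorFromBuffer(bufferization::BufferLikeType buffer) const override {
    // Inverse direction, used e.g. by ToTensorOp's builder.
    return cast<bufferization::TensorLikeType>(
        MyTensorType::get(buffer.getContext(), /*...*/));
  }
};

// Registered in the dialect's initialize():
//   addInterfaces<MyDialectConversionInterface>();
```

The method signatures mirror the header verbatim; only the type-mapping bodies are dialect-specific.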

llvmbot (Member) commented Jun 5, 2025

@llvm/pr-subscribers-mlir-sparse

Author: Andrei Golubev (andrey-golubev)

Changes

Following the introduction of TensorLike and BufferLike type interfaces (see 00eaff3), introduce minimal changes required to bufferize a custom tensor operation into a custom buffer operation.

To achieve this, a new conversion dialect interface is added that abstracts away the differences between existing (tensor -> memref) and custom conversions.

The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first understand the basics and reach consensus design-wise.


Patch is 49.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142986.diff

19 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+15-2)
  • (added) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h (+72)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+28-20)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+3-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+8-6)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+51-25)
  • (added) mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp (+67)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+11-10)
  • (modified) mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+4-4)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+28-23)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+8-6)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+20-1)
  • (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (+49)
  • (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+23)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.h (+1)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+56-2)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..8390da956444d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
 #include <optional>
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 
 namespace mlir {
 class OpBuilder;
@@ -615,7 +616,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         const BufferizationState &state);
 
@@ -629,7 +630,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         const BufferizationState &state,
                                         SmallVector<Value> &invocationStack);
@@ -738,6 +739,18 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
 /// This is the default implementation of
 /// BufferizableOpInterface::hasTensorSemantics
 bool defaultHasTensorSemantics(Operation *op);
+
+/// This is a helper function used when buffer type is guaranteed to be memref.
+FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to verify the types in tensor and buffer
+/// worlds match.
+bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to perform the conversion.
+Type getTensorFromBuffer(Type buffer);
 } // namespace detail
 
 } // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
new file mode 100644
index 0000000000000..4164d1dcb9ea6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
@@ -0,0 +1,72 @@
+//===- BufferizationConversionInterface.h - Dialect Interface ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
+#include "mlir/IR/DialectInterface.h"
+
+namespace mlir {
+namespace bufferization {
+
+/// This class defines a virtual interface for conversions between tensor-like
+/// and buffer-like types.
+struct ConversionDialectInterface
+    : DialectInterface::Base<ConversionDialectInterface> {
+  using Base::Base;
+
+  /// Hook to customize tensor-like -> buffer-like conversion within a given
+  /// dialect. Returns a buffer-like type for the specific tensor-like type.
+  virtual FailureOr<BufferLikeType> getBufferType(
+      Value value, const BufferizationOptions &options,
+      const BufferizationState &state,
+      function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+  /// Hook to customize type checking between tensor-like and buffer-like types.
+  /// Given tensor `T` and buffer `B = getBufferType(T, ...)`, the call to
+  /// `typesMatch(T, B)` must return true.
+  virtual LogicalResult typesMatch(
+      TensorLikeType tensor, BufferLikeType buffer,
+      function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+  /// Hook to customize buffer-like -> tensor-like conversion, which is the
+  /// opposite of bufferization.
+  virtual TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const = 0;
+};
+
+/// Interface collection for conversion between tensor-like and buffer-like
+/// types, dispatches to a concrete interface implementation based on the
+/// dialect to which the given type belongs.
+struct ConversionInterface
+    : DialectInterfaceCollection<ConversionDialectInterface> {
+  using Base::Base;
+
+  /// Dispatches to ConversionDialectInterface::getBufferType() of the dialect
+  /// associated with the value type.
+  FailureOr<BufferLikeType> getBufferType(
+      Value value, const BufferizationOptions &options,
+      const BufferizationState &state,
+      function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+  /// Dispatches to ConversionDialectInterface::typesMatch() of the dialect
+  /// associated with the value type.
+  LogicalResult
+  typesMatch(TensorLikeType tensor, BufferLikeType buffer,
+             function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+  /// Dispatches to ConversionDialectInterface::getTensorFromBuffer() of the
+  /// dialect associated with the value type.
+  TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const;
+};
+
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
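The collection's dispatch-by-owning-dialect behavior can be sketched with a self-contained C++ analog; the string-keyed registry and all names below are hypothetical stand-ins for illustration, not MLIR API:

```cpp
#include <functional>
#include <map>
#include <optional>
#include <string>
#include <utility>

// Hypothetical stand-in for a DialectInterfaceCollection: each "dialect"
// registers its own tensor-like -> buffer-like conversion hook, and the
// collection dispatches on the dialect that owns the queried type.
struct ConversionRegistry {
  using ConvertFn = std::function<std::string(const std::string &)>;
  std::map<std::string, ConvertFn> byDialect;

  void registerDialect(const std::string &dialect, ConvertFn fn) {
    byDialect[dialect] = std::move(fn);
  }

  // Dispatch: "mydialect.tensor" is resolved by the "mydialect" hook.
  std::optional<std::string> getBufferType(const std::string &type) const {
    auto dot = type.find('.');
    if (dot == std::string::npos)
      return std::nullopt;
    auto it = byDialect.find(type.substr(0, dot));
    if (it == byDialect.end())
      return std::nullopt; // no conversion interface for this dialect
    return it->second(type);
  }
};

inline ConversionRegistry makeExampleRegistry() {
  ConversionRegistry reg;
  // Builtin-like behavior: tensor bufferizes to memref.
  reg.registerDialect(
      "builtin", [](const std::string &) { return std::string("builtin.memref"); });
  // A custom dialect bufferizes its tensor into its own buffer type.
  reg.registerDialect(
      "mydialect", [](const std::string &) { return std::string("mydialect.buffer"); });
  return reg;
}
```

With such a registry, the buffer type of an unbufferizable op's value is produced by whichever dialect owns the tensor type, which is the role the ConversionInterface collection plays in this patch.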
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 3d4dcdee2663b..277d56bc3f647 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -387,20 +388,28 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
 // ToTensorOp
 //===----------------------------------------------------------------------===//
 
+class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
+  "specified tensor and buffer types match",
+  CPred<
+    "::mlir::bufferization::detail::typesMatchAfterBufferization("
+        "$_op, $" # tensor # ", $" # buffer #")"
+  >
+>;
+
 def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
-    AllElementTypesMatch<["memref", "result"]>
+    Bufferization_TensorAndBufferMatch<"result", "buffer">
   ]> {
-  let summary = "create a tensor from a `memref`";
+  let summary = "create a tensor-like type from a buffer-like type";
   let description = [{
-    An operation that creates a tensor from a `memref`. The result value is a
-    tensor whose shape and element type match the memref operand.
+    An operation that creates a tensor from a buffer. The result value is a
+    tensor-like type whose shape and element type match the buffer-like operand.
 
     The opposite of this op is `to_buffer`. Together, these two ops are
     useful for source/target materializations when doing type conversions
-    involving tensors and memrefs.
+    involving tensors and buffers.
 
     Example:
 
@@ -442,11 +451,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     away. However, such IR is no longer bufferizable with One-Shot Bufferize.
   }];
 
-  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+  let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
                            "the reference to load from",
-                           [MemReadAt<0, FullEffect>]>:$memref,
+                           [MemReadAt<0, FullEffect>]>:$buffer,
                        UnitAttr:$restrict, UnitAttr:$writable);
-  let results = (outs AnyTensor:$result);
+  let results = (outs Bufferization_TensorLikeTypeInterface:$result);
 
   let extraClassDeclaration = [{
     /// The result of a to_tensor is always a tensor.
@@ -473,19 +482,19 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
     }
   }];
 
   let assemblyFormat = [{
-    $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-      `:` type($memref) `to` type($result)
+    $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
+      `:` type($buffer) `to` type($result)
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
-      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
-      build($_builder, $_state, rtt, memref, restrict, writeable);
+    OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+      auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType());
+      build($_builder, $_state, rtt, buffer, restrict, writeable);
     }]>
   ];
 
@@ -503,10 +512,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
     Pure,
-    AllShapesMatch<["memref", "tensor"]>,
-    AllElementTypesMatch<["memref", "tensor"]>
+    Bufferization_TensorAndBufferMatch<"tensor", "buffer">
   ]> {
-  let summary = "cast a tensor to memref";
+  let summary = "cast a tensor-like type to buffer-like type";
   let description = [{
     An operation that returns the future buffer of a `tensor`.
 
@@ -524,8 +532,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     the returned buffer) will not be written to.
   }];
 
-  let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
-  let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+  let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
+  let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);
 
   let extraClassDeclaration = [{
     //===------------------------------------------------------------------===//
@@ -560,7 +568,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
   }];
 
   let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
   }];
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index a441b8b66659e..f56c10555f02c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         // The operand was already bufferized. Take its type directly.
         callerType = memrefType;
       } else {
-        FailureOr<BaseMemRefType> maybeCallerType =
+        FailureOr<BufferLikeType> maybeCallerType =
             bufferization::getBufferType(opOperand->get(), options, state,
                                          invocationStack);
         if (failed(maybeCallerType))
           return failure();
-        callerType = *maybeCallerType;
+        assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
+        callerType = cast<BaseMemRefType>(*maybeCallerType);
       }
 
       if (!bufferType) {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index a57d58ab28d28..021a557f68b4b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -164,8 +164,8 @@ struct SelectOpInterface
     // buffers have different types, they differ only in their layout map. Cast
     // both of them to the most dynamic MemRef type.
     if (trueBuffer.getType() != falseBuffer.getType()) {
-      auto targetType =
-          bufferization::getBufferType(selectOp.getResult(), options, state);
+      auto targetType = bufferization::detail::castToMemRef(
+          bufferization::getBufferType(selectOp.getResult(), options, state));
       if (failed(targetType))
         return failure();
       if (trueBuffer.getType() != *targetType)
@@ -187,10 +187,12 @@ struct SelectOpInterface
                 SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
     assert(value == selectOp.getResult() && "invalid value");
-    auto trueType = bufferization::getBufferType(
-        selectOp.getTrueValue(), options, state, invocationStack);
-    auto falseType = bufferization::getBufferType(
-        selectOp.getFalseValue(), options, state, invocationStack);
+    auto trueType =
+        bufferization::detail::castToMemRef(bufferization::getBufferType(
+            selectOp.getTrueValue(), options, state, invocationStack));
+    auto falseType =
+        bufferization::detail::castToMemRef(bufferization::getBufferType(
+            selectOp.getFalseValue(), options, state, invocationStack));
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..d00605a7b9a17 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -211,8 +212,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
   if (copy)
     return allocTensorOp.getResult();
-  FailureOr<BaseMemRefType> copyBufferType =
-      getBufferType(tensor, options, state);
+  auto copyBufferType =
+      detail::castToMemRef(getBufferType(tensor, options, state));
   if (failed(copyBufferType))
     return failure();
   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -673,28 +674,28 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                           const BufferizationOptions &options,
                                           const BufferizationState &state) {
 #ifndef NDEBUG
-  auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
+  auto tensorType = llvm::dyn_cast<TensorLikeType>(value.getType());
   assert(tensorType && "unexpected non-tensor type");
 #endif // NDEBUG
 
   // Replace "%t = to_tensor %m" with %m.
   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
-    return toTensorOp.getMemref();
+    return toTensorOp.getBuffer();
 
   // Insert to_buffer op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
-  if (failed(memrefType))
+  FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
+  if (failed(bufferType))
     return failure();
-  ensureToBufferOpIsValid(value, *memrefType);
+  ensureToBufferOpIsValid(value, *bufferType);
   return rewriter
-      .create<bufferization::ToBufferOp>(value.getLoc(), *memrefType, value)
+      .create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value)
       .getResult();
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state) {
   SmallVector<Value> invocationStack;
@@ -702,11 +703,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
-  assert(llvm::isa<TensorType>(value.getType()) &&
+  assert(llvm::isa<TensorLikeType>(value.getType()) &&
          "unexpected non-tensor type");
   invocationStack.push_back(value);
   auto popFromStack =
@@ -718,13 +719,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
   if (bufferizableOp)
     return bufferizableOp.getBufferType(value, options, state, invocationStack);
 
-  // Op is not bufferizable.
-  auto memSpace =
-      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
-  if (!memSpace.has_value())
-    return op->emitError("could not infer memory space");
-
-  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+  // Op is not bufferizable, use conversion interface.
+  bufferization::ConversionInterface iface(value.getContext());
+  return iface.getBufferType(value, options, state, [&](const Twine &message) {
+    return op->emitError(message);
+  });
 }
 
 bool bufferization::hasTensorSemantics(Operation *op) {
@@ -744,12 +743,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
   SmallVector<Value> replacements;
   for (OpResult opResult : op->getOpResults()) {
     Value replacement = values[opResult.getResultNumber()];
-    if (llvm::isa<TensorType>(opResult.getType())) {
+    if (llvm::isa<TensorLikeType>(opResult.getType())) {
       // The OpResult is a tensor. Such values are replaced with memrefs during
       // bufferization.
-      assert((llvm::isa<MemRefType>(replacement.getType()) ||
-              llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
-             "tensor op result should be replaced with a memref value");
+      assert(llvm::isa<BufferLikeType>(replacement.getType()) &&
+             "tensor op result should be replaced with a buffer value");
       // The existing uses of the OpResult still expect a tensor. Insert a
       // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
       // lose all of its users and eventually DCE away.
@@ -970,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return getBufferType(equivalentOperand, options, bufferizationState,
-                         invocationStack);
+    return castToMemRef(getBufferType(equivalentOperand, options,
+                                      bufferizationState, invocationStack));
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -1031,7 +1029,7 @@ bufferization::detail::unknownGe...
[truncated]

@andrey-golubev
Contributor Author

@christopherbate I think you might find this interesting also

Contributor

@hanhanW hanhanW left a comment

Many thanks for the upstream contribution! I've seen some changes like this that'd break downstream projects. It is very reasonable because you're improving the bufferization.

I'd appreciate it if you can list down some naming changes in the PR description before you land the PR. It'd save us some time to look at all the details. E.g., the operand name of ToTensorOp becomes buffer; the return type of bufferization::getBufferType becomes FailureOr<BufferLikeType>, etc;

(It is definitely not a requirement, and it is okay if you don't do it. I appreciate your contribution. Thanks again! 🙏)

@andrey-golubev
Contributor Author

I'd appreciate it if you can list down some naming changes in the PR description before you land the PR. It'd save us some time to look at all the details. E.g., the operand name of ToTensorOp becomes buffer; the return type of bufferization::getBufferType becomes FailureOr<BufferLikeType>, etc;

Not sure this is good PR description-wise, but I mentioned the two changes (actually, what you mention are probably the only two here?).

@andrey-golubev
Contributor Author

@matthias-springer (gently pinging)

auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
build($_builder, $_state, rtt, memref, restrict, writeable);
OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType());
Contributor Author

@matthias-springer I took a look at the current implementation. I think there's a problem with this API if we switch to a BufferLike's function:

  • ToTensor seems to infer the tensor type from buffer type
  • Without options here, there's no (good) way to customize the (reverse) bufferization behavior for builtins

Should this be changed somehow? The things I could think of:

  1. Drop type inference and make ToTensor always be constructed with an explicit type (so, user has to care about reverse bufferization)
  2. Pass bufferization options to the builder
  3. Assume that the current API is the way it is -> meaning that options cannot be used as customization point for builtins also and a different mechanism is needed
  4. something else?

Member

If we go with the above-mentioned verifyCompatibleTensorType approach, type inference must be dropped, indeed. I think that's alright. We can add some helper functions to make it easy for folks to migrate their code.

I wouldn't pass the BufferizationOptions to the op builder. I think it's better to explicitly specify the result type when constructing a to_tensor / to_buffer op.

Contributor Author

If we go with the above-mentioned verifyCompatibleTensorType approach, type inference must be dropped, indeed. I think that's alright. We can add some helper functions to make it easy for folks to migrate their code.

Turns out this builder was such a helper function. In fact, many places seem to just be able to provide the valid tensor type without additional reverse conversions.

class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
"specified tensor and buffer types match",
CPred<
"::mlir::bufferization::detail::typesMatchAfterBufferization("
Contributor Author

@matthias-springer this would be the other problematic place. With the current approach, we want to validate that bufferization is "valid" on a tensor <-> buffer level.

The current logic checks tensor.getShape() == buffer.getShape() and tensor.getElementType() == buffer.getElementType(); this practically means that TensorLike and BufferLike are ShapedType (fine by me), but even that is not enough. We've recently started to experiment with shape bounds and dynamic shapes, so getShape checking might not be sufficient.

Instead, I think we should either restore the old comparison logic (which was changed in ced2fc7) or - more likely - have this put into an interface so that it's a customization point.

But then, which interface? TensorLike? BufferLike? Since it's a type matching function, it's kind of valid to be in both.

Member

How about some kind of double dispatch? E.g., for to_tensor:

  • Op verifier checks that operand type implements BufferLikeTypeInterface.
  • Op verifier checks that result type implements TensorLikeTypeInterface.
  • Op verifier calls bufferLikeType.verifyCompatibleTensorType(tensorLikeType). The result is std::optional<LogicalResult>. success means that the type is compatible, failure means that it is incompatible, nullopt means that we don't know.
  • In case of nullopt, call tensorLikeType.verifyCompatibleBufferType(bufferLikeType).
  • If the result is still nullopt, we fail verification because neither of the two types "know" each other.
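The two-sided query proposed above can be sketched as a self-contained C++ analog, with std::optional&lt;bool&gt; standing in for std::optional&lt;LogicalResult&gt; and all type names hypothetical:

```cpp
#include <optional>
#include <string>

// A custom buffer type only understands its own dialect's tensor and the
// builtin tensor it was designed to bufferize; anything else is "don't know".
std::optional<bool> bufferKnowsTensor(const std::string &buffer,
                                      const std::string &tensor) {
  if (buffer == "custom.buffer")
    return tensor == "custom.tensor" || tensor == "builtin.tensor";
  return std::nullopt; // this buffer type cannot decide
}

// The builtin tensor's interface only knows about builtin memrefs.
std::optional<bool> tensorKnowsBuffer(const std::string &tensor,
                                      const std::string &buffer) {
  if (tensor == "builtin.tensor")
    return buffer == "builtin.memref";
  return std::nullopt;
}

// Query the buffer side first, then the tensor side; fail verification
// only when neither side can decide.
bool verifyCompatible(const std::string &tensor, const std::string &buffer) {
  if (auto result = bufferKnowsTensor(buffer, tensor))
    return *result;
  if (auto result = tensorKnowsBuffer(tensor, buffer))
    return *result;
  return false; // neither type "knows" the other
}
```

Note how this resolves the builtin-to-custom case: RankedTensorType's interface returns "don't know" for a custom buffer, but the custom buffer's interface can still accept the builtin tensor.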

Contributor Author

@andrey-golubev andrey-golubev Jun 17, 2025

I wouldn't go with double dispatch, to be honest, because there's largely no difference between "tensor equivalent to buffer" vs "buffer equivalent to tensor" (we have both values, and they don't change between the two calls). For the time being, I guess we just put it somewhere? (either to buffer-like or to tensor-like). Perhaps with more changes it would be clearer what to do here.

Member

@matthias-springer matthias-springer Jun 17, 2025

Sry, double dispatch is the wrong name. It's more like querying both interfaces.

The reason why I'm suggesting this is to support custom conversions for builtin types. The type interface implementation of RankedTensorType won't know about your custom buffer type, so it cannot verify type compatibility. But the type interface implementation of your custom buffer type can do the verification.

Contributor Author

@andrey-golubev andrey-golubev Jun 17, 2025

I see what you mean now. It's actually more (or less?) straightforward:
What we have is:

  1. custom.tensor -> custom.buffer <-- this is "problem-free" (because type interface implementation is on us)
  2. tensor<..., {custom encoding}> -> memref<..., {custom layout}> <-- this is meh

there's never really a situation where we'd bufferize builtin into non-builtin or non-builtin into builtin, but the case is interesting (I'd perhaps add support for this separately if that has any use).

edit: I guess for 2. we can keep the "default" logic - which is what's on main right now - and then let's see where it brings us.

Contributor Author

Implemented the TensorLikeType half of the querying.

@andrey-golubev
Contributor Author

andrey-golubev commented Jun 17, 2025

With the last update,

  • BufferizationOptions and getMemRefType (helper) had to be updated to work on TensorType instead of Value (actually, I was thinking of doing this for a long time)
  • ToTensorOp's helper builder is dropped

Is it worth splitting these two into separate patches? (Maybe also separate "memref" -> "buffer" renames in ToTensor/ToBuffer ops, etc.).

Alternatively, I guess I can just update the PR description to mention these changes. (Edit: updated the PR description for now)

@aartbik
Contributor

aartbik commented Jun 17, 2025

LGTM for the sparse changes (which are rather mechanical)

Member

@matthias-springer matthias-springer left a comment

Looks good overall!

@@ -267,7 +268,7 @@ struct BufferizationOptions {
/// Tensor -> MemRef type converter.
/// Parameters: Value, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
Value, Attribute memorySpace, const BufferizationOptions &)>;
TensorType, Attribute memorySpace, const BufferizationOptions &)>;
Member

Why did this change? Can we keep Value here to keep the number of API changes small?

Contributor Author

I need this together with getMemRefType() API change since TensorLike::getBufferType() works with types (there are no values). see https://github.com/llvm/llvm-project/pull/142986/files#diff-93e08b1e03259ffa2ff33aec1bd8f907f066e1d86cf83cf140fb031fb31c6f3aR71-R72:

llvm::FailureOr<BufferLikeType> getBufferType(mlir::Type tensor, // mlir::Type here now
  const BufferizationOptions &options,
  const BufferizationState &state,
  llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
  auto tensorType = cast<TensorType>(tensor);

  // ...
      
  return cast<BufferLikeType>(
    getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
}

Contributor Author

I could separate this into another patch if you prefer

Member

@matthias-springer matthias-springer Jun 18, 2025

If this can be a separate PR, let's do that. Integrating bufferization-related changes is often quite difficult for downstream users, so I try to keep them as small as possible, so they can integrate one at a time.

Contributor Author

separated - #144658

Value tensor,
Value buffer) {
assert(isa<TensorLikeType>(tensor.getType()) && "expected TensorLikeType");
assert(isa<BufferLikeType>(buffer.getType()) && "expected BufferLikeType");
Member

nit: dyn_cast, then you don't need to cast a second time below

Contributor Author

@andrey-golubev andrey-golubev Jun 18, 2025

do you mean something like: auto x = dyn_cast<Blah>(y) then assert(x != nullptr)?
I keep them separated like this since release builds would likely strip asserts and so one gets just a cheap llvm::cast<> (no runtime checks afair) instead of dyn_cast/isa.

Member

I think llvm::cast still does the check and asserts. Actually, this means you can just remove the assertions entirely.

Contributor Author

done.

const BufferizationState &state,
llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
auto tensorType = cast<TensorType>(tensor);
// Fall back to tensor -> memref conversion.
Member

I'm a bit confused about the comment. Why is this a "fall back"?

Contributor Author

@andrey-golubev andrey-golubev Jun 18, 2025

never mind, outdated comment that i missed when moving stuff around.

edit: removed the comments.

/*methodName=*/"getBufferType",
/*args=*/(ins
"const ::mlir::bufferization::BufferizationOptions &":$options,
"const ::mlir::bufferization::BufferizationState &":$state,
Member

Ideally, I'd like to avoid passing the BufferizationState here. Do we really need it?

Contributor Author

honestly, i have no idea what this state thing is... I don't really need it myself, imho options is enough. I saw that getBufferType accepts it now, so there's a chance it's actually used by something? (no clue). let me try to make some sense out of it.

Member

The state can be used during bufferization to store symbol tables, etc. We don't need it during type conversions.

Contributor Author

awesome! removed the state from this API.

@@ -738,6 +740,14 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
/// This is the default implementation of
/// BufferizableOpInterface::hasTensorSemantics
bool defaultHasTensorSemantics(Operation *op);

/// This is a helper function used when buffer type is guaranteed to be memref.
FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
Member

nit: Can you rename this to asMemRefType? The term "cast" is overloaded, it could refer to memref.cast. Let's also document the exact behavior of the function: failure -> failure.

Contributor Author

done.
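The documented contract of the renamed helper (failure in, failure out; otherwise narrow to the memref case) can be sketched in plain C++; BufferLike and the nullopt-as-failure encoding below are hypothetical stand-ins for FailureOr&lt;BufferLikeType&gt;, not the actual MLIR helper:

```cpp
#include <optional>
#include <string>

// Hypothetical stand-in for a buffer-like type tagged with its kind.
struct BufferLike {
  std::string kind;
};

// Analog of detail::asMemRefType: narrow a possibly-failed buffer-like
// result to the memref case. Failure (nullopt) propagates unchanged; in
// this sketch a non-memref buffer is also reported as failure instead of
// asserting.
std::optional<BufferLike> asMemRefLike(std::optional<BufferLike> buffer) {
  if (!buffer)
    return std::nullopt; // failure -> failure
  if (buffer->kind != "memref")
    return std::nullopt; // not a memref: cannot narrow
  return buffer;
}
```

Callers that are guaranteed to deal with memrefs (e.g. the arith.select bufferization above) can then unwrap the narrowed result directly.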

An operation that creates a tensor from a `memref`. The result value is a
tensor whose shape and element type match the memref operand.
An operation that creates a tensor from a buffer. The result value is a
tensor-like type whose shape and element type match the buffer-like operand.
Member

Mention in the documentation that operand and result types must be compatible as per TensorLikeTypeInterface::verifyCompatibleBufferType.

Contributor Author

good point. rewrote this piece.

Following the introduction of TensorLike and BufferLike type interfaces
(see 00eaff3), introduce minimal
changes required to bufferize a custom tensor operation into a custom
buffer operation.

To achieve this, a new conversion dialect interface is added that
abstracts away the differences between existing (tensor -> memref) and
custom conversions.

The scope of the changes is intentionally limited (for example,
BufferizableOpInterface is untouched) in order to first understand the
basics and reach consensus design-wise.

The builder is ambiguous given customizable tensor-like -> buffer-like
conversion and is thus removed. The places where reverse bufferization
has to happen rely on the pre-existing functionality.

Noteworthy changes:
* bufferization::getMemRefType() accepts a TensorType instead of Value
  to achieve broader applicability
* BufferizationOptions::UnknownTypeConverterFn accepts a TensorType
  instead of Value to allow it being used in the updated getMemRefType()
@andrey-golubev
Contributor Author

andrey-golubev commented Jun 18, 2025

@matthias-springer do you mind merging this as well (again, I have no rights :|), or should we wait for more reviews/feedback? I only see Aart's LGTM for the sparsity-related changes.

@matthias-springer matthias-springer merged commit ee070d0 into llvm:main Jun 18, 2025
9 checks passed
@andrey-golubev andrey-golubev deleted the bufferize_custom_type branch June 18, 2025 14:25
basioli-k added a commit to basioli-k/llvm-project that referenced this pull request Jun 18, 2025
basioli-k added a commit that referenced this pull request Jun 18, 2025
bjacob added a commit to iree-org/llvm-project that referenced this pull request Jun 20, 2025
lialan pushed a commit to iree-org/llvm-project that referenced this pull request Jun 23, 2025
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Jun 24, 2025
umangyadav added a commit to ROCm/rocMLIR that referenced this pull request Jun 24, 2025
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Jun 25, 2025
lialan pushed a commit to iree-org/llvm-project that referenced this pull request Jun 26, 2025
5 participants