[mlir][bufferization] Fix OneShotBufferize when defaultMemorySpaceFn is used #91524

christopherbate (Contributor) commented May 8, 2024

As described in issue #91518, a previous PR
#78484 introduced the defaultMemorySpaceFn into
bufferization options, allowing one to inform OneShotBufferize that it
should use a specified function to derive the memory space attribute
from the encoding attribute attached to tensor types.
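
To make the role of defaultMemorySpaceFn concrete, here is a minimal standalone sketch. It does not use MLIR: `ToyTensorType`, `ToyMemRefType`, `MemorySpaceFn`, `encodingAsMemorySpace`, and `bufferTypeFor` are hypothetical names invented for illustration, integers stand in for attributes, and the fallback to memory space 0 is a simplification. In the patch itself the callback returns `std::optional<Attribute>`.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

// Toy stand-ins for tensor and memref types. These are NOT the mlir:: classes.
struct ToyTensorType {
  std::vector<int64_t> shape;
  std::optional<int64_t> encoding; // e.g. the `1` in tensor<3xf32, 1>
};

struct ToyMemRefType {
  std::vector<int64_t> shape;
  int64_t memorySpace; // 0 models the default memory space
};

// Hypothetical analogue of the defaultMemorySpaceFn option: given a tensor
// type, return the memory space to use, or nullopt if none can be derived.
using MemorySpaceFn =
    std::function<std::optional<int64_t>(const ToyTensorType &)>;

// Policy sketched by this PR: reuse the tensor encoding as the memory space.
inline std::optional<int64_t> encodingAsMemorySpace(const ToyTensorType &t) {
  return t.encoding;
}

// Choose the buffer type for a tensor. Falling back to memory space 0 when
// the callback gives no answer is an assumption made for this sketch only.
inline ToyMemRefType bufferTypeFor(const ToyTensorType &t,
                                   const MemorySpaceFn &fn) {
  return ToyMemRefType{t.shape, fn(t).value_or(0)};
}

// Usage: a tensor with encoding 1 is placed in memory space 1.
inline void checkMemorySpaceExample() {
  ToyTensorType encoded{{3}, 1};
  assert(bufferTypeFor(encoded, encodingAsMemorySpace).memorySpace == 1);
}
```

The point of the callback shape is that the memory space is decided per tensor type, so every place that builds a buffer type has to consult it.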

However, introducing this feature exposed unhandled edge cases. This
change adds examples of them in the new test under
test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir.

Fixing the inconsistencies introduced by defaultMemorySpaceFn is pretty
simple. This change:

  • Updates the bufferization.to_memref and bufferization.to_tensor
    operations to explicitly include operand and destination types,
    whereas previously they relied on type inference to deduce the
    tensor types. Since the type inference cannot recover the correct
    tensor encoding/memory space, the operand and result types must be
    explicitly included. This is a small assembly format change, but it
    touches a large number of test files.

  • Makes minor updates to other bufferization functions to handle the
    changes in building the above ops.

  • Updates bufferization of tensor.from_elements to handle memory
    space.
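
The first bullet can be pictured with a toy model of the problem. This is a sketch under stated assumptions, not MLIR code: `ToyTensor`, `ToyMemRef`, `inferTensorFromMemRef`, `exactlyEqual`, and `matchUpToEncoding` are hypothetical names, and only the ranked case is modeled. It shows that a tensor type inferred from the memref alone carries no encoding, so an exact comparison against the declared tensor type fails while a comparison "up to encoding" succeeds.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Toy stand-ins for ranked tensor and memref types. NOT the mlir:: classes.
struct ToyTensor {
  std::vector<int64_t> shape;
  std::string elementType;
  std::optional<int64_t> encoding;
};

struct ToyMemRef {
  std::vector<int64_t> shape;
  std::string elementType;
  std::optional<int64_t> memorySpace;
};

// Models inferring the tensor type from the memref type alone: shape and
// element type carry over, but no encoding can be produced.
inline ToyTensor inferTensorFromMemRef(const ToyMemRef &m) {
  return ToyTensor{m.shape, m.elementType, std::nullopt};
}

// Strict comparison, including the encoding.
inline bool exactlyEqual(const ToyTensor &a, const ToyTensor &b) {
  return a.shape == b.shape && a.elementType == b.elementType &&
         a.encoding == b.encoding;
}

// Relaxed comparison that ignores the encoding (ranked case only).
inline bool matchUpToEncoding(const ToyTensor &a, const ToyTensor &b) {
  return a.elementType == b.elementType && a.shape == b.shape;
}

// Usage: the declared tensor<3xf32, 1> cannot be recovered from the buffer.
inline void checkInferenceExample() {
  ToyTensor declared{{3}, "f32", 1};
  ToyMemRef buffer{{3}, "f32", 1};
  assert(!exactlyEqual(inferTensorFromMemRef(buffer), declared));
  assert(matchUpToEncoding(inferTensorFromMemRef(buffer), declared));
}
```

Because the encoding is lost in one direction, the ops spell out both types instead of relying on inference.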

Integration/upgrade guide:

In downstream projects, if you have tests or MLIR files that explicitly use
bufferization.to_tensor or bufferization.to_memref, then update
them to the new assembly format as follows:

%1 = bufferization.to_memref %0 : memref<10xf32>
%2 = bufferization.to_tensor %1 : memref<10xf32>

becomes

%1 = bufferization.to_memref %0 : tensor<10xf32> to memref<10xf32>
%2 = bufferization.to_tensor %1 : memref<10xf32> to tensor<10xf32>
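
For downstream files with many such lines, a small rewrite helper may save time. The following is a hypothetical sketch, not a tool shipped with MLIR: `upgradeLine` is an invented name, and it only handles single-line ops whose memref type has no layout or memory space (no commas or nested `<...>`), where the tensor type is simply the memref type with the keyword swapped. It uses the `to` separator from the guide above; the truncated diff in this thread shows `->`, so adjust the separator to whatever your MLIR checkout accepts. Anything it does not recognize is returned unchanged and should be updated by hand.

```cpp
#include <cassert>
#include <regex>
#include <string>

// Hypothetical migration helper: rewrite one line from the old assembly
// format to the new one. Lines that are not a simple old-format
// to_memref/to_tensor op are returned unchanged.
inline std::string upgradeLine(const std::string &line) {
  // Group 1: everything up to and including ": "
  // Group 2: the op name (to_memref or to_tensor)
  // Group 3: the memref body, restricted to types without ',', '<' or '>'
  static const std::regex pattern(
      R"((.*bufferization\.(to_memref|to_tensor)\b.*:\s*)memref<([^,<>]*)>\s*)");
  std::smatch m;
  if (!std::regex_match(line, m, pattern))
    return line;
  const std::string body = m[3].str();
  const std::string memref = "memref<" + body + ">";
  const std::string tensor = "tensor<" + body + ">";
  if (m[2].str() == "to_memref")
    return m[1].str() + tensor + " to " + memref; // operand type first
  return m[1].str() + memref + " to " + tensor;
}

// Usage: the two lines from the upgrade guide.
inline void checkUpgradeExample() {
  assert(upgradeLine("%1 = bufferization.to_memref %0 : memref<10xf32>") ==
         "%1 = bufferization.to_memref %0 : tensor<10xf32> to memref<10xf32>");
  assert(upgradeLine("%2 = bufferization.to_tensor %1 : memref<10xf32>") ==
         "%2 = bufferization.to_tensor %1 : memref<10xf32> to tensor<10xf32>");
}
```

Types with layouts, memory spaces, or encodings are deliberately skipped, since those are exactly the cases where the tensor type cannot be derived from the memref type.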

llvmbot (Member) commented May 8, 2024

@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-amx
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Christopher Bate (christopherbate)

Changes

As described in issue llvm/llvm-project#91518, a previous
PR llvm/llvm-project#78484 introduced the defaultMemorySpaceFn into bufferization
options, allowing one to inform OneShotBufferize that it should use a specified
function to derive the memory space attribute from the encoding attribute attached
to tensor types.

However, introducing this feature exposed unhandled edge cases, examples of which
are introduced by this change in the new test under
test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir.

Fixing the inconsistencies introduced by defaultMemorySpaceFn is pretty
simple. This change:

  • updates the bufferization.to_memref and bufferization.to_tensor operations
    to explicitly include operand and destination types, whereas previously they
    relied on type inference to deduce the tensor types. Since the type inference
    cannot recover the correct tensor encoding/memory space, the operand and result
    types must be explicitly included.
  • makes minor updates to other bufferization functions to handle the
    changes in building the above ops
  • updates bufferization of tensor.from_elements to handle memory space

Patch is 226.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91524.diff

68 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h (+6)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+12-6)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td (+4)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+10)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+12-1)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+12-10)
  • (modified) mlir/test/Dialect/Arith/bufferize.mlir (+3-3)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir (+1-1)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir (+15-16)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir (+2-2)
  • (added) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir (+133)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir (+3-3)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+3-3)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir (+1-1)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir (+2-2)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+16-16)
  • (modified) mlir/test/Dialect/Bufferization/ops.mlir (+3-3)
  • (modified) mlir/test/Dialect/ControlFlow/one-shot-bufferize.mlir (+2-2)
  • (modified) mlir/test/Dialect/Func/func-bufferize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Linalg/bufferize.mlir (+5-5)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+1-1)
  • (modified) mlir/test/Dialect/MemRef/normalize-memrefs.mlir (+2-2)
  • (modified) mlir/test/Dialect/SCF/bufferize.mlir (+6-6)
  • (modified) mlir/test/Dialect/Shape/bufferize.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matvec_lib.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_sddmm_lib.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/constant_index_map.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/dense.mlir (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sorted_coo.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_1d.mlir (+14-14)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_2d.mlir (+39-39)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_3d.mlir (+41-41)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_affine.mlir (+8-8)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_batch.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir (+11-11)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_fusion.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir (+17-17)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_kernels.mlir (+9-9)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_lower.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_nd.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_pack.mlir (+6-6)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_perm.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_scalars.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir (+5-5)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/spy_sddmm.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/unused-tensor.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+14-14)
  • (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+17-17)
  • (modified) mlir/test/Dialect/Tensor/one-shot-bufferize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Vector/bufferize.mlir (+3-3)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf-full.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-full.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 6f19dca2e8222..d6ccbdd7acc1f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -12,10 +12,16 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir::bufferization::detail {
+bool tensorTypesMatchUpToEncoding(Type lhs, Type rhs);
+} // namespace mlir::bufferization::detail
 
 //===----------------------------------------------------------------------===//
 // Bufferization Dialect
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a41..7be44d566d481 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -388,9 +388,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
-    TypesMatchWith<"result type matches tensor equivalent of 'memref'",
-                   "memref", "result",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+    AllElementTypesMatch<["memref", "result"]>
   ]> {
   let summary = "create a tensor from a `memref`";
   let description = [{
@@ -477,9 +475,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
   let assemblyFormat = [{
     $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-      `:` type($memref)
+      `:` type($memref) `->` type($result)
   }];
 
+  let builders = [
+    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
+      build($_builder, $_state, rtt, memref, restrict, writeable);
+    }]>
+  ];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
@@ -496,7 +501,8 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     Pure,
     TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
                    "memref", "tensor",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+                   "memref::getTensorTypeFromMemRefType($_self)",
+                   "bufferization::detail::tensorTypesMatchUpToEncoding">
   ]> {
   let summary = "cast a tensor to memref";
   let description = [{
@@ -551,7 +557,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
   }];
 
   let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `->` type($memref)
   }];
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 75ce85c9128c9..656edbfb3deaa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -520,6 +520,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
            /*default=*/"false",
            "The memory space of an memref types must always be inferred. If "
            "unset, a default memory space of 0 is used otherwise.">,
+    Option<"useEncodingForMemorySpace", "use-encoding-for-memory-space", "bool",
+            /*default=*/"false",
+            "Use the Tensor encoding attribute for the memory space. Exclusive to"
+            " the 'must-infer-memory-space option'">,
     Option<"testAnalysisOnly", "test-analysis-only", "bool",
             /*default=*/"false",
            "Test only: Only run inplaceability analysis and annotate IR">,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d51d63f243ea0..550ac7e83b9e0 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -719,7 +719,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
       // loose all of its users and eventually DCE away.
       rewriter.setInsertionPointAfter(op);
       replacement = rewriter.create<bufferization::ToTensorOp>(
-          replacement.getLoc(), replacement);
+          replacement.getLoc(), opResult.getType(), replacement);
     }
     replacements.push_back(replacement);
   }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313..bfb742e5e0176 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -23,6 +23,16 @@ using namespace mlir::bufferization;
 // Helper functions
 //===----------------------------------------------------------------------===//
 
+bool bufferization::detail::tensorTypesMatchUpToEncoding(Type lhs, Type rhs) {
+  auto lhsType = cast<ShapedType>(lhs);
+  auto rhsType = cast<ShapedType>(rhs);
+  if (lhsType.getElementType() != rhsType.getElementType())
+    return false;
+  if (lhsType.hasRank() && rhsType.hasRank())
+    return lhsType.getShape() == rhsType.getShape();
+  return true;
+}
+
 FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
     OpBuilder &b, Value value, MemRefType destType,
     const BufferizationOptions &options) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 7ba347a1f15e4..b43041d629dd3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -67,10 +67,14 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
                               ValueRange inputs, Location loc) -> Value {
     assert(inputs.size() == 1 && "expected exactly one input");
 
+    // Unranked to ranked casts must be explicit.
+    if (auto inputType = dyn_cast<UnrankedMemRefType>(inputs[0].getType()))
+      return nullptr;
+
     if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
       // MemRef to MemRef cast.
       assert(inputType != type && "expected different types");
-      // Unranked to ranked and ranked to unranked casts must be explicit.
+      // Ranked to unranked casts must be explicit.
       auto rankedDestType = dyn_cast<MemRefType>(type);
       if (!rankedDestType)
         return nullptr;
@@ -222,6 +226,13 @@ struct OneShotBufferizePass
             [](TensorType t) -> std::optional<Attribute> {
           return std::nullopt;
         };
+      } else if (useEncodingForMemorySpace) {
+        opt.defaultMemorySpaceFn =
+            [](TensorType t) -> std::optional<Attribute> {
+          if (auto rtt = dyn_cast<RankedTensorType>(t))
+            return rtt.getEncoding();
+          return std::nullopt;
+        };
       }
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index d078a575f40dd..a46f500b76c3f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -387,8 +387,8 @@ struct ExtractSliceOpInterface
     if (failed(resultMemrefType))
       return failure();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
-        mixedSizes, mixedStrides);
+        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
+        mixedOffsets, mixedSizes, mixedStrides);
 
     replaceOpWithBufferizedValues(rewriter, op, subView);
     return success();
@@ -407,8 +407,9 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
     return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
-        extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
-        mixedOffsets, mixedSizes, mixedStrides));
+        extractSliceOp.getType().getShape(),
+        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
+        mixedStrides));
   }
 };
 
@@ -478,9 +479,8 @@ struct FromElementsOpInterface
     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
     auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
 
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
-      return op->emitError("memory space not implemented yet");
+    std::optional<Attribute> memorySpace =
+        options.defaultMemorySpaceFn(tensorType);
 
     // Allocate a buffer for the result.
     Location loc = op->getLoc();
@@ -491,10 +491,12 @@ struct FromElementsOpInterface
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    auto memrefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    FailureOr<BaseMemRefType> memrefType =
+        bufferization::getBufferType(*tensorAlloc, options);
+    if (failed(memrefType))
+      return failure();
     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, *tensorAlloc);
+        op->getLoc(), *memrefType, *tensorAlloc);
 
     // Case: tensor<0xelem_type>.
     if (fromElementsOp.getElements().empty()) {
diff --git a/mlir/test/Dialect/Arith/bufferize.mlir b/mlir/test/Dialect/Arith/bufferize.mlir
index 944954e9e4edd..31b4577cdd62f 100644
--- a/mlir/test/Dialect/Arith/bufferize.mlir
+++ b/mlir/test/Dialect/Arith/bufferize.mlir
@@ -8,7 +8,7 @@ func.func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, ind
   %index_scalar = arith.index_cast %scalar : i32 to index
   return %index_tensor, %index_scalar : tensor<index>, index
 }
-// CHECK:  %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32>
+// CHECK:  %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<i32>
 // CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
 // CHECK-SAME:   memref<i32> to memref<index>
 // CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
@@ -87,8 +87,8 @@ func.func @non_tensor() {
 // CHECK-SAME:                 %[[PRED:.*]]: i1,
 // CHECK-SAME:                 %[[TRUE_VAL:.*]]: tensor<f32>,
 // CHECK-SAME:                 %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG:           %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref<f32>
-// CHECK-DAG:           %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK-DAG:           %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : tensor<f32>
+// CHECK-DAG:           %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : tensor<f32>
 // CHECK:           %[[RET_MEMREF:.*]] = arith.select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref<f32>
 // CHECK:           return %[[RET]] : tensor<f32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
index 5293977fe733f..55e086ff0110f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
@@ -9,7 +9,7 @@
 //  CHECK-NEXT:   %[[clone:.*]] = bufferization.clone %[[m]]
 //  CHECK-NEXT:   return %[[clone]]
 func.func private @no_interface_no_operands(%t : tensor<?x?x?xf16>) -> memref<?x?x?xf16> {
-  %0 = bufferization.to_memref %t : memref<?x?x?xf16>
+  %0 = bufferization.to_memref %t : tensor<?x?x?xf16> -> memref<?x?x?xf16>
   return %0 : memref<?x?x?xf16>
 }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index ff94c1b331d92..500bdb4f9afc5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -4,8 +4,8 @@
 // CHECK-SAME:                                     %[[ARG:.*]]: memref<f32>) -> memref<f32> {
 // CHECK:           return %[[ARG]] : memref<f32>
 func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
-  %0 = bufferization.to_tensor %arg0 : memref<f32>
-  %1 = bufferization.to_memref %0 : memref<f32>
+  %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
+  %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
   return %1 : memref<f32>
 }
 
@@ -14,14 +14,14 @@ func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
 func.func @unable_to_convert_lone_buffer_cast() -> memref<f32> {
   // expected-error @+1 {{failed to legalize operation 'test.source'}}
   %0 = "test.source"() : () -> tensor<f32>
-  %1 = bufferization.to_memref %0 : memref<f32>
+  %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
   return %1 : memref<f32>
 }
 
 // -----
 
 func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
-  %0 = bufferization.to_tensor %arg0 : memref<f32>
+  %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
   // expected-error @+1 {{failed to legalize operation 'test.sink'}}
   "test.sink"(%0) : (tensor<f32>) -> ()
   return
@@ -37,8 +37,8 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -52,8 +52,8 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -67,8 +67,8 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -77,9 +77,9 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
 // TODO: to_memref with layout maps not supported yet. This should fold to a
 // memref.cast.
 func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
-  %0 = bufferization.to_tensor %m : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
   // expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
-  %1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32, strided<[1], offset: ?>>
   // expected-note @+1 {{see existing live user here}}
   return %1 : memref<?xf32, strided<[1], offset: ?>>
 }
@@ -87,9 +87,8 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
 // -----
 
 func.func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
-  // expected-note @+1 {{prior use here}}
-  %0 = bufferization.to_tensor %m : memref<*xf32>
-  // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<*xf32> -> tensor<?xf32>
+  // expected-error @+1 {{failed to legalize unresolved materialization from 'memref<*xf32>' to 'memref<?xf32>' that remained live after conversion}}
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index c3e44c426797f..b74934039506b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -96,7 +96,7 @@ func.func @to_memref_not_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
-  %m = bufferization.to_memref %t : memref<5xf32>
+  %m = bufferization.to_memref %t : tensor<5xf32> -> memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
@@ -112,7 +112,7 @@ func.func @to_memref_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
-  %m = bufferization.to_memref %t {read_only} : memref<5xf32>
+  %m = bufferization.to_memref %t {read_only} : tensor<5xf32> -> memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
new file mode 100644
index 0000000000000..f892ae95e697d
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space" -split-input-file | FileCheck %s
+
+// TODO: move to tensor dialect tests
+func.func @from_elements(%fill: f32, %f: f32, %idx: index) -> tensor<3xf32, 1> {
+  %t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
+  %i = tensor.insert %f into %t[%idx] : tensor<3xf32, 1>
+  return %i : tensor<3xf32, 1>
+}
+
+// CHECK-LABEL: @from_elements
+//  CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: f32, %[[arg2:.+]]: index) -> tensor<3xf32, 1 : i64>
+//       CHECK:     %[[alloc:.+]] = memref.alloc() {{.*}} : memref<3xf32, 1>
+//       CHECK-DAG:     %[[c0:.+]] = arith.constant 0 : index
+//       CHECK-DAG:     %[[c1:.+]] = arith.constant 1 : index
+//       CHECK-DAG:     %[[c2:.+]] = arith.constant 2 : index
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c0]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c1]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c2]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg1]], %[[alloc]][%[[arg2]]] : mem...
[truncated]

llvmbot (Member) commented May 8, 2024

@llvm/pr-subscribers-mlir-gpu

  • (modified) mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sorted_coo.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_1d.mlir (+14-14)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_2d.mlir (+39-39)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_3d.mlir (+41-41)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_affine.mlir (+8-8)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_batch.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir (+11-11)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_fusion.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir (+17-17)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_kernels.mlir (+9-9)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_lower.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_nd.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_pack.mlir (+6-6)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_parallel_reduce.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_perm.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_scalars.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir (+5-5)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_sddmm_org.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/spy_sddmm.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/spy_sddmm_bsr.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/unused-tensor.mlir (+2-2)
  • (modified) mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir (+14-14)
  • (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+17-17)
  • (modified) mlir/test/Dialect/Tensor/one-shot-bufferize.mlir (+1-1)
  • (modified) mlir/test/Dialect/Vector/bufferize.mlir (+3-3)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf-full.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli-full.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 6f19dca2e8222..d6ccbdd7acc1f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -12,10 +12,16 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir::bufferization::detail {
+bool tensorTypesMatchUpToEncoding(Type lhs, Type rhs);
+} // namespace mlir::bufferization::detail
 
 //===----------------------------------------------------------------------===//
 // Bufferization Dialect
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a41..7be44d566d481 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -388,9 +388,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
-    TypesMatchWith<"result type matches tensor equivalent of 'memref'",
-                   "memref", "result",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+    AllElementTypesMatch<["memref", "result"]>
   ]> {
   let summary = "create a tensor from a `memref`";
   let description = [{
@@ -477,9 +475,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
   let assemblyFormat = [{
     $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-      `:` type($memref)
+      `:` type($memref) `->` type($result)
   }];
 
+  let builders = [
+    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
+      build($_builder, $_state, rtt, memref, restrict, writeable);
+    }]>
+  ];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
@@ -496,7 +501,8 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     Pure,
     TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
                    "memref", "tensor",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+                   "memref::getTensorTypeFromMemRefType($_self)",
+                   "bufferization::detail::tensorTypesMatchUpToEncoding">
   ]> {
   let summary = "cast a tensor to memref";
   let description = [{
@@ -551,7 +557,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
   }];
 
   let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `->` type($memref)
   }];
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 75ce85c9128c9..656edbfb3deaa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -520,6 +520,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
            /*default=*/"false",
            "The memory space of an memref types must always be inferred. If "
            "unset, a default memory space of 0 is used otherwise.">,
+    Option<"useEncodingForMemorySpace", "use-encoding-for-memory-space", "bool",
+            /*default=*/"false",
+            "Use the Tensor encoding attribute for the memory space. Exclusive to"
+            " the 'must-infer-memory-space option'">,
     Option<"testAnalysisOnly", "test-analysis-only", "bool",
             /*default=*/"false",
            "Test only: Only run inplaceability analysis and annotate IR">,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d51d63f243ea0..550ac7e83b9e0 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -719,7 +719,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
       // loose all of its users and eventually DCE away.
       rewriter.setInsertionPointAfter(op);
       replacement = rewriter.create<bufferization::ToTensorOp>(
-          replacement.getLoc(), replacement);
+          replacement.getLoc(), opResult.getType(), replacement);
     }
     replacements.push_back(replacement);
   }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313..bfb742e5e0176 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -23,6 +23,16 @@ using namespace mlir::bufferization;
 // Helper functions
 //===----------------------------------------------------------------------===//
 
+bool bufferization::detail::tensorTypesMatchUpToEncoding(Type lhs, Type rhs) {
+  auto lhsType = cast<ShapedType>(lhs);
+  auto rhsType = cast<ShapedType>(rhs);
+  if (lhsType.getElementType() != rhsType.getElementType())
+    return false;
+  if (lhsType.hasRank() && rhsType.hasRank())
+    return lhsType.getShape() == rhsType.getShape();
+  return true;
+}
+
 FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
     OpBuilder &b, Value value, MemRefType destType,
     const BufferizationOptions &options) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 7ba347a1f15e4..b43041d629dd3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -67,10 +67,14 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
                               ValueRange inputs, Location loc) -> Value {
     assert(inputs.size() == 1 && "expected exactly one input");
 
+    // Unranked to ranked casts must be explicit.
+    if (auto inputType = dyn_cast<UnrankedMemRefType>(inputs[0].getType()))
+      return nullptr;
+
     if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
       // MemRef to MemRef cast.
       assert(inputType != type && "expected different types");
-      // Unranked to ranked and ranked to unranked casts must be explicit.
+      // Ranked to unranked casts must be explicit.
       auto rankedDestType = dyn_cast<MemRefType>(type);
       if (!rankedDestType)
         return nullptr;
@@ -222,6 +226,13 @@ struct OneShotBufferizePass
             [](TensorType t) -> std::optional<Attribute> {
           return std::nullopt;
         };
+      } else if (useEncodingForMemorySpace) {
+        opt.defaultMemorySpaceFn =
+            [](TensorType t) -> std::optional<Attribute> {
+          if (auto rtt = dyn_cast<RankedTensorType>(t))
+            return rtt.getEncoding();
+          return std::nullopt;
+        };
       }
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index d078a575f40dd..a46f500b76c3f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -387,8 +387,8 @@ struct ExtractSliceOpInterface
     if (failed(resultMemrefType))
       return failure();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
-        mixedSizes, mixedStrides);
+        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
+        mixedOffsets, mixedSizes, mixedStrides);
 
     replaceOpWithBufferizedValues(rewriter, op, subView);
     return success();
@@ -407,8 +407,9 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
     return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
-        extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
-        mixedOffsets, mixedSizes, mixedStrides));
+        extractSliceOp.getType().getShape(),
+        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
+        mixedStrides));
   }
 };
 
@@ -478,9 +479,8 @@ struct FromElementsOpInterface
     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
     auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
 
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
-      return op->emitError("memory space not implemented yet");
+    std::optional<Attribute> memorySpace =
+        options.defaultMemorySpaceFn(tensorType);
 
     // Allocate a buffer for the result.
     Location loc = op->getLoc();
@@ -491,10 +491,12 @@ struct FromElementsOpInterface
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    auto memrefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    FailureOr<BaseMemRefType> memrefType =
+        bufferization::getBufferType(*tensorAlloc, options);
+    if (failed(memrefType))
+      return failure();
     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, *tensorAlloc);
+        op->getLoc(), *memrefType, *tensorAlloc);
 
     // Case: tensor<0xelem_type>.
     if (fromElementsOp.getElements().empty()) {
diff --git a/mlir/test/Dialect/Arith/bufferize.mlir b/mlir/test/Dialect/Arith/bufferize.mlir
index 944954e9e4edd..31b4577cdd62f 100644
--- a/mlir/test/Dialect/Arith/bufferize.mlir
+++ b/mlir/test/Dialect/Arith/bufferize.mlir
@@ -8,7 +8,7 @@ func.func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, ind
   %index_scalar = arith.index_cast %scalar : i32 to index
   return %index_tensor, %index_scalar : tensor<index>, index
 }
-// CHECK:  %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32>
+// CHECK:  %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<i32>
 // CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
 // CHECK-SAME:   memref<i32> to memref<index>
 // CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
@@ -87,8 +87,8 @@ func.func @non_tensor() {
 // CHECK-SAME:                 %[[PRED:.*]]: i1,
 // CHECK-SAME:                 %[[TRUE_VAL:.*]]: tensor<f32>,
 // CHECK-SAME:                 %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG:           %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref<f32>
-// CHECK-DAG:           %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK-DAG:           %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : tensor<f32>
+// CHECK-DAG:           %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : tensor<f32>
 // CHECK:           %[[RET_MEMREF:.*]] = arith.select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref<f32>
 // CHECK:           return %[[RET]] : tensor<f32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
index 5293977fe733f..55e086ff0110f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
@@ -9,7 +9,7 @@
 //  CHECK-NEXT:   %[[clone:.*]] = bufferization.clone %[[m]]
 //  CHECK-NEXT:   return %[[clone]]
 func.func private @no_interface_no_operands(%t : tensor<?x?x?xf16>) -> memref<?x?x?xf16> {
-  %0 = bufferization.to_memref %t : memref<?x?x?xf16>
+  %0 = bufferization.to_memref %t : tensor<?x?x?xf16> -> memref<?x?x?xf16>
   return %0 : memref<?x?x?xf16>
 }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index ff94c1b331d92..500bdb4f9afc5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -4,8 +4,8 @@
 // CHECK-SAME:                                     %[[ARG:.*]]: memref<f32>) -> memref<f32> {
 // CHECK:           return %[[ARG]] : memref<f32>
 func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
-  %0 = bufferization.to_tensor %arg0 : memref<f32>
-  %1 = bufferization.to_memref %0 : memref<f32>
+  %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
+  %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
   return %1 : memref<f32>
 }
 
@@ -14,14 +14,14 @@ func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
 func.func @unable_to_convert_lone_buffer_cast() -> memref<f32> {
   // expected-error @+1 {{failed to legalize operation 'test.source'}}
   %0 = "test.source"() : () -> tensor<f32>
-  %1 = bufferization.to_memref %0 : memref<f32>
+  %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
   return %1 : memref<f32>
 }
 
 // -----
 
 func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
-  %0 = bufferization.to_tensor %arg0 : memref<f32>
+  %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
   // expected-error @+1 {{failed to legalize operation 'test.sink'}}
   "test.sink"(%0) : (tensor<f32>) -> ()
   return
@@ -37,8 +37,8 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -52,8 +52,8 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -67,8 +67,8 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -77,9 +77,9 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
 // TODO: to_memref with layout maps not supported yet. This should fold to a
 // memref.cast.
 func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
-  %0 = bufferization.to_tensor %m : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
   // expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
-  %1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32, strided<[1], offset: ?>>
   // expected-note @+1 {{see existing live user here}}
   return %1 : memref<?xf32, strided<[1], offset: ?>>
 }
@@ -87,9 +87,8 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
 // -----
 
 func.func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
-  // expected-note @+1 {{prior use here}}
-  %0 = bufferization.to_tensor %m : memref<*xf32>
-  // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<*xf32> -> tensor<?xf32>
+  // expected-error @+1 {{failed to legalize unresolved materialization from 'memref<*xf32>' to 'memref<?xf32>' that remained live after conversion}}
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index c3e44c426797f..b74934039506b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -96,7 +96,7 @@ func.func @to_memref_not_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
-  %m = bufferization.to_memref %t : memref<5xf32>
+  %m = bufferization.to_memref %t : tensor<5xf32> -> memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
@@ -112,7 +112,7 @@ func.func @to_memref_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
-  %m = bufferization.to_memref %t {read_only} : memref<5xf32>
+  %m = bufferization.to_memref %t {read_only} : tensor<5xf32> -> memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
new file mode 100644
index 0000000000000..f892ae95e697d
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space" -split-input-file | FileCheck %s
+
+// TODO: move to tensor dialect tests
+func.func @from_elements(%fill: f32, %f: f32, %idx: index) -> tensor<3xf32, 1> {
+  %t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
+  %i = tensor.insert %f into %t[%idx] : tensor<3xf32, 1>
+  return %i : tensor<3xf32, 1>
+}
+
+// CHECK-LABEL: @from_elements
+//  CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: f32, %[[arg2:.+]]: index) -> tensor<3xf32, 1 : i64>
+//       CHECK:     %[[alloc:.+]] = memref.alloc() {{.*}} : memref<3xf32, 1>
+//       CHECK-DAG:     %[[c0:.+]] = arith.constant 0 : index
+//       CHECK-DAG:     %[[c1:.+]] = arith.constant 1 : index
+//       CHECK-DAG:     %[[c2:.+]] = arith.constant 2 : index
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c0]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c1]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c2]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg1]], %[[alloc]][%[[arg2]]] : mem...
[truncated]

@llvmbot
Member

llvmbot commented May 8, 2024

@llvm/pr-subscribers-mlir-bufferization

Author: Christopher Bate (christopherbate)

Changes

As described in issue llvm/llvm-project#91518, a previous
PR llvm/llvm-project#78484 introduced the defaultMemorySpaceFn into bufferization
options, allowing one to inform OneShotBufferize that it should use a specified
function to derive the memory space attribute from the encoding attribute attached
to tensor types.

However, introducing this feature exposed unhandled edge cases, examples of which
are introduced by this change in the new test under
test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir.

Fixing the inconsistencies introduced by defaultMemorySpaceFn is pretty
simple. This change:

  • updates the bufferization.to_memref and bufferization.to_tensor operations
    to explicitly include operand and destination types, whereas previously they
    relied on type inference to deduce the tensor types. Since the type inference
    cannot recover the correct tensor encoding/memory space, the operand and result
    types must be explicitly included.
  • makes minor updates to other bufferization functions to handle the
    changes in building the above ops
  • updates bufferization of tensor.from_elements to handle memory space

Patch is 226.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91524.diff

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 6f19dca2e8222..d6ccbdd7acc1f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -12,10 +12,16 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir::bufferization::detail {
+bool tensorTypesMatchUpToEncoding(Type lhs, Type rhs);
+} // namespace mlir::bufferization::detail
 
 //===----------------------------------------------------------------------===//
 // Bufferization Dialect
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a41..7be44d566d481 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -388,9 +388,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
-    TypesMatchWith<"result type matches tensor equivalent of 'memref'",
-                   "memref", "result",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+    AllElementTypesMatch<["memref", "result"]>
   ]> {
   let summary = "create a tensor from a `memref`";
   let description = [{
@@ -477,9 +475,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
   let assemblyFormat = [{
     $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-      `:` type($memref)
+      `:` type($memref) `->` type($result)
   }];
 
+  let builders = [
+    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
+      build($_builder, $_state, rtt, memref, restrict, writeable);
+    }]>
+  ];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
@@ -496,7 +501,8 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     Pure,
     TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
                    "memref", "tensor",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+                   "memref::getTensorTypeFromMemRefType($_self)",
+                   "bufferization::detail::tensorTypesMatchUpToEncoding">
   ]> {
   let summary = "cast a tensor to memref";
   let description = [{
@@ -551,7 +557,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
   }];
 
   let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `->` type($memref)
   }];
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 75ce85c9128c9..656edbfb3deaa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -520,6 +520,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
            /*default=*/"false",
            "The memory space of an memref types must always be inferred. If "
            "unset, a default memory space of 0 is used otherwise.">,
+    Option<"useEncodingForMemorySpace", "use-encoding-for-memory-space", "bool",
+            /*default=*/"false",
+            "Use the Tensor encoding attribute for the memory space. Exclusive to"
+            " the 'must-infer-memory-space option'">,
     Option<"testAnalysisOnly", "test-analysis-only", "bool",
             /*default=*/"false",
            "Test only: Only run inplaceability analysis and annotate IR">,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d51d63f243ea0..550ac7e83b9e0 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -719,7 +719,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
       // loose all of its users and eventually DCE away.
       rewriter.setInsertionPointAfter(op);
       replacement = rewriter.create<bufferization::ToTensorOp>(
-          replacement.getLoc(), replacement);
+          replacement.getLoc(), opResult.getType(), replacement);
     }
     replacements.push_back(replacement);
   }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313..bfb742e5e0176 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -23,6 +23,16 @@ using namespace mlir::bufferization;
 // Helper functions
 //===----------------------------------------------------------------------===//
 
+bool bufferization::detail::tensorTypesMatchUpToEncoding(Type lhs, Type rhs) {
+  auto lhsType = cast<ShapedType>(lhs);
+  auto rhsType = cast<ShapedType>(rhs);
+  if (lhsType.getElementType() != rhsType.getElementType())
+    return false;
+  if (lhsType.hasRank() && rhsType.hasRank())
+    return lhsType.getShape() == rhsType.getShape();
+  return true;
+}
+
 FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
     OpBuilder &b, Value value, MemRefType destType,
     const BufferizationOptions &options) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 7ba347a1f15e4..b43041d629dd3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -67,10 +67,14 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
                               ValueRange inputs, Location loc) -> Value {
     assert(inputs.size() == 1 && "expected exactly one input");
 
+    // Unranked to ranked casts must be explicit.
+    if (auto inputType = dyn_cast<UnrankedMemRefType>(inputs[0].getType()))
+      return nullptr;
+
     if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
       // MemRef to MemRef cast.
       assert(inputType != type && "expected different types");
-      // Unranked to ranked and ranked to unranked casts must be explicit.
+      // Ranked to unranked casts must be explicit.
       auto rankedDestType = dyn_cast<MemRefType>(type);
       if (!rankedDestType)
         return nullptr;
@@ -222,6 +226,13 @@ struct OneShotBufferizePass
             [](TensorType t) -> std::optional<Attribute> {
           return std::nullopt;
         };
+      } else if (useEncodingForMemorySpace) {
+        opt.defaultMemorySpaceFn =
+            [](TensorType t) -> std::optional<Attribute> {
+          if (auto rtt = dyn_cast<RankedTensorType>(t))
+            return rtt.getEncoding();
+          return std::nullopt;
+        };
       }
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index d078a575f40dd..a46f500b76c3f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -387,8 +387,8 @@ struct ExtractSliceOpInterface
     if (failed(resultMemrefType))
       return failure();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
-        mixedSizes, mixedStrides);
+        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
+        mixedOffsets, mixedSizes, mixedStrides);
 
     replaceOpWithBufferizedValues(rewriter, op, subView);
     return success();
@@ -407,8 +407,9 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
     return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
-        extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
-        mixedOffsets, mixedSizes, mixedStrides));
+        extractSliceOp.getType().getShape(),
+        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
+        mixedStrides));
   }
 };
 
@@ -478,9 +479,8 @@ struct FromElementsOpInterface
     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
     auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
 
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
-      return op->emitError("memory space not implemented yet");
+    std::optional<Attribute> memorySpace =
+        options.defaultMemorySpaceFn(tensorType);
 
     // Allocate a buffer for the result.
     Location loc = op->getLoc();
@@ -491,10 +491,12 @@ struct FromElementsOpInterface
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    auto memrefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    FailureOr<BaseMemRefType> memrefType =
+        bufferization::getBufferType(*tensorAlloc, options);
+    if (failed(memrefType))
+      return failure();
     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, *tensorAlloc);
+        op->getLoc(), *memrefType, *tensorAlloc);
 
     // Case: tensor<0xelem_type>.
     if (fromElementsOp.getElements().empty()) {
diff --git a/mlir/test/Dialect/Arith/bufferize.mlir b/mlir/test/Dialect/Arith/bufferize.mlir
index 944954e9e4edd..31b4577cdd62f 100644
--- a/mlir/test/Dialect/Arith/bufferize.mlir
+++ b/mlir/test/Dialect/Arith/bufferize.mlir
@@ -8,7 +8,7 @@ func.func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, ind
   %index_scalar = arith.index_cast %scalar : i32 to index
   return %index_tensor, %index_scalar : tensor<index>, index
 }
-// CHECK:  %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32>
+// CHECK:  %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<i32>
 // CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
 // CHECK-SAME:   memref<i32> to memref<index>
 // CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
@@ -87,8 +87,8 @@ func.func @non_tensor() {
 // CHECK-SAME:                 %[[PRED:.*]]: i1,
 // CHECK-SAME:                 %[[TRUE_VAL:.*]]: tensor<f32>,
 // CHECK-SAME:                 %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG:           %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref<f32>
-// CHECK-DAG:           %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK-DAG:           %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : tensor<f32>
+// CHECK-DAG:           %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : tensor<f32>
 // CHECK:           %[[RET_MEMREF:.*]] = arith.select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref<f32>
 // CHECK:           return %[[RET]] : tensor<f32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
index 5293977fe733f..55e086ff0110f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
@@ -9,7 +9,7 @@
 //  CHECK-NEXT:   %[[clone:.*]] = bufferization.clone %[[m]]
 //  CHECK-NEXT:   return %[[clone]]
 func.func private @no_interface_no_operands(%t : tensor<?x?x?xf16>) -> memref<?x?x?xf16> {
-  %0 = bufferization.to_memref %t : memref<?x?x?xf16>
+  %0 = bufferization.to_memref %t : tensor<?x?x?xf16> -> memref<?x?x?xf16>
   return %0 : memref<?x?x?xf16>
 }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index ff94c1b331d92..500bdb4f9afc5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -4,8 +4,8 @@
 // CHECK-SAME:                                     %[[ARG:.*]]: memref<f32>) -> memref<f32> {
 // CHECK:           return %[[ARG]] : memref<f32>
 func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
-  %0 = bufferization.to_tensor %arg0 : memref<f32>
-  %1 = bufferization.to_memref %0 : memref<f32>
+  %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
+  %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
   return %1 : memref<f32>
 }
 
@@ -14,14 +14,14 @@ func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
 func.func @unable_to_convert_lone_buffer_cast() -> memref<f32> {
   // expected-error @+1 {{failed to legalize operation 'test.source'}}
   %0 = "test.source"() : () -> tensor<f32>
-  %1 = bufferization.to_memref %0 : memref<f32>
+  %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
   return %1 : memref<f32>
 }
 
 // -----
 
 func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
-  %0 = bufferization.to_tensor %arg0 : memref<f32>
+  %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
   // expected-error @+1 {{failed to legalize operation 'test.sink'}}
   "test.sink"(%0) : (tensor<f32>) -> ()
   return
@@ -37,8 +37,8 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -52,8 +52,8 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -67,8 +67,8 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -77,9 +77,9 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
 // TODO: to_memref with layout maps not supported yet. This should fold to a
 // memref.cast.
 func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
-  %0 = bufferization.to_tensor %m : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
   // expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
-  %1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32, strided<[1], offset: ?>>
   // expected-note @+1 {{see existing live user here}}
   return %1 : memref<?xf32, strided<[1], offset: ?>>
 }
@@ -87,9 +87,8 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
 // -----
 
 func.func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
-  // expected-note @+1 {{prior use here}}
-  %0 = bufferization.to_tensor %m : memref<*xf32>
-  // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<*xf32> -> tensor<?xf32>
+  // expected-error @+1 {{failed to legalize unresolved materialization from 'memref<*xf32>' to 'memref<?xf32>' that remained live after conversion}}
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index c3e44c426797f..b74934039506b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -96,7 +96,7 @@ func.func @to_memref_not_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
-  %m = bufferization.to_memref %t : memref<5xf32>
+  %m = bufferization.to_memref %t : tensor<5xf32> -> memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
@@ -112,7 +112,7 @@ func.func @to_memref_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
-  %m = bufferization.to_memref %t {read_only} : memref<5xf32>
+  %m = bufferization.to_memref %t {read_only} : tensor<5xf32> -> memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
new file mode 100644
index 0000000000000..f892ae95e697d
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space" -split-input-file | FileCheck %s
+
+// TODO: move to tensor dialect tests
+func.func @from_elements(%fill: f32, %f: f32, %idx: index) -> tensor<3xf32, 1> {
+  %t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
+  %i = tensor.insert %f into %t[%idx] : tensor<3xf32, 1>
+  return %i : tensor<3xf32, 1>
+}
+
+// CHECK-LABEL: @from_elements
+//  CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: f32, %[[arg2:.+]]: index) -> tensor<3xf32, 1 : i64>
+//       CHECK:     %[[alloc:.+]] = memref.alloc() {{.*}} : memref<3xf32, 1>
+//       CHECK-DAG:     %[[c0:.+]] = arith.constant 0 : index
+//       CHECK-DAG:     %[[c1:.+]] = arith.constant 1 : index
+//       CHECK-DAG:     %[[c2:.+]] = arith.constant 2 : index
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c0]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c1]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c2]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg1]], %[[alloc]][%[[arg2]]] : mem...
[truncated]

@llvmbot
Member

llvmbot commented May 8, 2024

@llvm/pr-subscribers-mlir-func

Author: Christopher Bate (christopherbate)

Changes

As described in issue llvm/llvm-project#91518, a previous
PR llvm/llvm-project#78484 introduced the defaultMemorySpaceFn into bufferization
options, allowing one to inform OneShotBufferize that it should use a specified
function to derive the memory space attribute from the encoding attribute attached
to tensor types.

However, introducing this feature exposed unhandled edge cases, examples of which
are introduced by this change in the new test under
test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir.

Fixing the inconsistencies introduced by defaultMemorySpaceFn is pretty
simple. This change:

  • updates the bufferization.to_memref and bufferization.to_tensor operations
    to explicitly include operand and destination types, whereas previously they
    relied on type inference to deduce the tensor types. Since the type inference
    cannot recover the correct tensor encoding/memory space, the operand and result
    types must be explicitly included.
  • makes minor updates to other bufferization functions to handle the
    changes in building the above ops
  • updates bufferization of tensor.from_elements to handle memory space
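
To illustrate the first bullet: because the tensor type can no longer be inferred from the memref type alone, the patch relaxes the to_memref verifier to compare types "up to encoding". The following is a minimal standalone sketch of that comparison, not the MLIR implementation. TypeModel, matchUpToEncoding, and exampleUsage are hypothetical names introduced here for illustration; the real helper in the patch is bufferization::detail::tensorTypesMatchUpToEncoding and operates on mlir::ShapedType.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::ShapedType (assumption: not an MLIR class).
struct TypeModel {
  std::string elementType;                    // e.g. "f32"
  std::optional<std::vector<int64_t>> shape;  // std::nullopt models an unranked type
  std::string encoding;                       // e.g. "1"; empty means no encoding
};

// Follows the control flow of tensorTypesMatchUpToEncoding in the patch:
// element types must agree, shapes must agree when both sides are ranked,
// and the encoding is never inspected.
bool matchUpToEncoding(const TypeModel &lhs, const TypeModel &rhs) {
  if (lhs.elementType != rhs.elementType)
    return false;
  if (lhs.shape && rhs.shape)
    return *lhs.shape == *rhs.shape;
  return true;
}

// Usage: a tensor with an encoding still matches the plain tensor type that
// would be derived from the memref, while a shape mismatch is rejected.
void exampleUsage() {
  TypeModel withEncoding{"f32", std::vector<int64_t>{3}, "1"};
  TypeModel plain{"f32", std::vector<int64_t>{3}, ""};
  TypeModel otherShape{"f32", std::vector<int64_t>{4}, ""};
  assert(matchUpToEncoding(withEncoding, plain));
  assert(!matchUpToEncoding(plain, otherShape));
}
```

In this sketch, as in the patch, an unranked type on either side skips the shape check entirely, so only the element type is compared in that case.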

Patch is 226.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91524.diff

68 Files Affected (the file list is identical to the one in the preceding llvmbot notification).
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
index 6f19dca2e8222..d6ccbdd7acc1f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -12,10 +12,16 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir::bufferization::detail {
+bool tensorTypesMatchUpToEncoding(Type lhs, Type rhs);
+} // namespace mlir::bufferization::detail
 
 //===----------------------------------------------------------------------===//
 // Bufferization Dialect
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a41..7be44d566d481 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -388,9 +388,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
-    TypesMatchWith<"result type matches tensor equivalent of 'memref'",
-                   "memref", "result",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+    AllElementTypesMatch<["memref", "result"]>
   ]> {
   let summary = "create a tensor from a `memref`";
   let description = [{
@@ -477,9 +475,16 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
   let assemblyFormat = [{
     $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-      `:` type($memref)
+      `:` type($memref) `->` type($result)
   }];
 
+  let builders = [
+    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
+      build($_builder, $_state, rtt, memref, restrict, writeable);
+    }]>
+  ];
+
   let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
@@ -496,7 +501,8 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     Pure,
     TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
                    "memref", "tensor",
-                   "memref::getTensorTypeFromMemRefType($_self)">
+                   "memref::getTensorTypeFromMemRefType($_self)",
+                   "bufferization::detail::tensorTypesMatchUpToEncoding">
   ]> {
   let summary = "cast a tensor to memref";
   let description = [{
@@ -551,7 +557,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
   }];
 
   let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `->` type($memref)
   }];
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index 75ce85c9128c9..656edbfb3deaa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -520,6 +520,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
            /*default=*/"false",
            "The memory space of an memref types must always be inferred. If "
            "unset, a default memory space of 0 is used otherwise.">,
+    Option<"useEncodingForMemorySpace", "use-encoding-for-memory-space", "bool",
+            /*default=*/"false",
+            "Use the Tensor encoding attribute for the memory space. Exclusive to"
+            " the 'must-infer-memory-space option'">,
     Option<"testAnalysisOnly", "test-analysis-only", "bool",
             /*default=*/"false",
            "Test only: Only run inplaceability analysis and annotate IR">,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index d51d63f243ea0..550ac7e83b9e0 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -719,7 +719,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
       // loose all of its users and eventually DCE away.
       rewriter.setInsertionPointAfter(op);
       replacement = rewriter.create<bufferization::ToTensorOp>(
-          replacement.getLoc(), replacement);
+          replacement.getLoc(), opResult.getType(), replacement);
     }
     replacements.push_back(replacement);
   }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313..bfb742e5e0176 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -23,6 +23,16 @@ using namespace mlir::bufferization;
 // Helper functions
 //===----------------------------------------------------------------------===//
 
+bool bufferization::detail::tensorTypesMatchUpToEncoding(Type lhs, Type rhs) {
+  auto lhsType = cast<ShapedType>(lhs);
+  auto rhsType = cast<ShapedType>(rhs);
+  if (lhsType.getElementType() != rhsType.getElementType())
+    return false;
+  if (lhsType.hasRank() && rhsType.hasRank())
+    return lhsType.getShape() == rhsType.getShape();
+  return true;
+}
+
 FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
     OpBuilder &b, Value value, MemRefType destType,
     const BufferizationOptions &options) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 7ba347a1f15e4..b43041d629dd3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -67,10 +67,14 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
                               ValueRange inputs, Location loc) -> Value {
     assert(inputs.size() == 1 && "expected exactly one input");
 
+    // Unranked to ranked casts must be explicit.
+    if (auto inputType = dyn_cast<UnrankedMemRefType>(inputs[0].getType()))
+      return nullptr;
+
     if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
       // MemRef to MemRef cast.
       assert(inputType != type && "expected different types");
-      // Unranked to ranked and ranked to unranked casts must be explicit.
+      // Ranked to unranked casts must be explicit.
       auto rankedDestType = dyn_cast<MemRefType>(type);
       if (!rankedDestType)
         return nullptr;
@@ -222,6 +226,13 @@ struct OneShotBufferizePass
             [](TensorType t) -> std::optional<Attribute> {
           return std::nullopt;
         };
+      } else if (useEncodingForMemorySpace) {
+        opt.defaultMemorySpaceFn =
+            [](TensorType t) -> std::optional<Attribute> {
+          if (auto rtt = dyn_cast<RankedTensorType>(t))
+            return rtt.getEncoding();
+          return std::nullopt;
+        };
       }
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index d078a575f40dd..a46f500b76c3f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -387,8 +387,8 @@ struct ExtractSliceOpInterface
     if (failed(resultMemrefType))
       return failure();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref, mixedOffsets,
-        mixedSizes, mixedStrides);
+        loc, llvm::cast<MemRefType>(*resultMemrefType), *srcMemref,
+        mixedOffsets, mixedSizes, mixedStrides);
 
     replaceOpWithBufferizedValues(rewriter, op, subView);
     return success();
@@ -407,8 +407,9 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
     return cast<BaseMemRefType>(memref::SubViewOp::inferRankReducedResultType(
-        extractSliceOp.getType().getShape(), llvm::cast<MemRefType>(*srcMemrefType),
-        mixedOffsets, mixedSizes, mixedStrides));
+        extractSliceOp.getType().getShape(),
+        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
+        mixedStrides));
   }
 };
 
@@ -478,9 +479,8 @@ struct FromElementsOpInterface
     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
     auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
 
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
-      return op->emitError("memory space not implemented yet");
+    std::optional<Attribute> memorySpace =
+        options.defaultMemorySpaceFn(tensorType);
 
     // Allocate a buffer for the result.
     Location loc = op->getLoc();
@@ -491,10 +491,12 @@ struct FromElementsOpInterface
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    auto memrefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    FailureOr<BaseMemRefType> memrefType =
+        bufferization::getBufferType(*tensorAlloc, options);
+    if (failed(memrefType))
+      return failure();
     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, *tensorAlloc);
+        op->getLoc(), *memrefType, *tensorAlloc);
 
     // Case: tensor<0xelem_type>.
     if (fromElementsOp.getElements().empty()) {
diff --git a/mlir/test/Dialect/Arith/bufferize.mlir b/mlir/test/Dialect/Arith/bufferize.mlir
index 944954e9e4edd..31b4577cdd62f 100644
--- a/mlir/test/Dialect/Arith/bufferize.mlir
+++ b/mlir/test/Dialect/Arith/bufferize.mlir
@@ -8,7 +8,7 @@ func.func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, ind
   %index_scalar = arith.index_cast %scalar : i32 to index
   return %index_tensor, %index_scalar : tensor<index>, index
 }
-// CHECK:  %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32>
+// CHECK:  %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : tensor<i32>
 // CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
 // CHECK-SAME:   memref<i32> to memref<index>
 // CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
@@ -87,8 +87,8 @@ func.func @non_tensor() {
 // CHECK-SAME:                 %[[PRED:.*]]: i1,
 // CHECK-SAME:                 %[[TRUE_VAL:.*]]: tensor<f32>,
 // CHECK-SAME:                 %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
-// CHECK-DAG:           %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : memref<f32>
-// CHECK-DAG:           %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK-DAG:           %[[TRUE_VAL_MEMREF:.*]] = bufferization.to_memref %[[TRUE_VAL]] : tensor<f32>
+// CHECK-DAG:           %[[FALSE_VAL_MEMREF:.*]] = bufferization.to_memref %[[FALSE_VAL]] : tensor<f32>
 // CHECK:           %[[RET_MEMREF:.*]] = arith.select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[RET_MEMREF]] : memref<f32>
 // CHECK:           return %[[RET]] : tensor<f32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
index 5293977fe733f..55e086ff0110f 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-other.mlir
@@ -9,7 +9,7 @@
 //  CHECK-NEXT:   %[[clone:.*]] = bufferization.clone %[[m]]
 //  CHECK-NEXT:   return %[[clone]]
 func.func private @no_interface_no_operands(%t : tensor<?x?x?xf16>) -> memref<?x?x?xf16> {
-  %0 = bufferization.to_memref %t : memref<?x?x?xf16>
+  %0 = bufferization.to_memref %t : tensor<?x?x?xf16> -> memref<?x?x?xf16>
   return %0 : memref<?x?x?xf16>
 }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
index ff94c1b331d92..500bdb4f9afc5 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir
@@ -4,8 +4,8 @@
 // CHECK-SAME:                                     %[[ARG:.*]]: memref<f32>) -> memref<f32> {
 // CHECK:           return %[[ARG]] : memref<f32>
 func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
-  %0 = bufferization.to_tensor %arg0 : memref<f32>
-  %1 = bufferization.to_memref %0 : memref<f32>
+  %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
+  %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
   return %1 : memref<f32>
 }
 
@@ -14,14 +14,14 @@ func.func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
 func.func @unable_to_convert_lone_buffer_cast() -> memref<f32> {
   // expected-error @+1 {{failed to legalize operation 'test.source'}}
   %0 = "test.source"() : () -> tensor<f32>
-  %1 = bufferization.to_memref %0 : memref<f32>
+  %1 = bufferization.to_memref %0 : tensor<f32> -> memref<f32>
   return %1 : memref<f32>
 }
 
 // -----
 
 func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
-  %0 = bufferization.to_tensor %arg0 : memref<f32>
+  %0 = bufferization.to_tensor %arg0 : memref<f32> -> tensor<f32>
   // expected-error @+1 {{failed to legalize operation 'test.sink'}}
   "test.sink"(%0) : (tensor<f32>) -> ()
   return
@@ -37,8 +37,8 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: ?>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -52,8 +52,8 @@ func.func @dyn_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: ?
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offset: ?>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[100], offset: ?>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -67,8 +67,8 @@ func.func @fancy_layout_to_no_layout_cast(%m: memref<?xf32, strided<[100], offse
 //       CHECK:   memref.copy %[[arg]], %[[alloc]]
 //       CHECK:   return %[[alloc]]
 func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset: 25>>) -> memref<?xf32> {
-  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>>
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32, strided<[1], offset: 25>> -> tensor<?xf32>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
 
@@ -77,9 +77,9 @@ func.func @static_layout_to_no_layout_cast(%m: memref<?xf32, strided<[1], offset
 // TODO: to_memref with layout maps not supported yet. This should fold to a
 // memref.cast.
 func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, strided<[1], offset: ?>> {
-  %0 = bufferization.to_tensor %m : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<?xf32> -> tensor<?xf32>
   // expected-error @+1 {{failed to materialize conversion for result #0 of operation 'bufferization.to_memref' that remained live after conversion}}
-  %1 = bufferization.to_memref %0 : memref<?xf32, strided<[1], offset: ?>>
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32, strided<[1], offset: ?>>
   // expected-note @+1 {{see existing live user here}}
   return %1 : memref<?xf32, strided<[1], offset: ?>>
 }
@@ -87,9 +87,8 @@ func.func @no_layout_to_dyn_layout_cast(%m: memref<?xf32>) -> memref<?xf32, stri
 // -----
 
 func.func @illegal_unranked_to_rank(%m: memref<*xf32>) -> memref<?xf32> {
-  // expected-note @+1 {{prior use here}}
-  %0 = bufferization.to_tensor %m : memref<*xf32>
-  // expected-error @+1 {{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<*xf32>'}}
-  %1 = bufferization.to_memref %0 : memref<?xf32>
+  %0 = bufferization.to_tensor %m : memref<*xf32> -> tensor<?xf32>
+  // expected-error @+1 {{failed to legalize unresolved materialization from 'memref<*xf32>' to 'memref<?xf32>' that remained live after conversion}}
+  %1 = bufferization.to_memref %0 : tensor<?xf32> -> memref<?xf32>
   return %1 : memref<?xf32>
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
index c3e44c426797f..b74934039506b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -96,7 +96,7 @@ func.func @to_memref_not_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
-  %m = bufferization.to_memref %t : memref<5xf32>
+  %m = bufferization.to_memref %t : tensor<5xf32> -> memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
@@ -112,7 +112,7 @@ func.func @to_memref_read_only(%idx : index, %f: f32) -> f32 {
   // Some op may write into the result of to_memref later.
   // CHECK: bufferization.to_memref
   // CHECK-SAME: {__inplace_operands_attr__ = ["true"]}
-  %m = bufferization.to_memref %t {read_only} : memref<5xf32>
+  %m = bufferization.to_memref %t {read_only} : tensor<5xf32> -> memref<5xf32>
   %2 = tensor.extract %t[%idx] : tensor<5xf32>
   return %2 : f32
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
new file mode 100644
index 0000000000000..f892ae95e697d
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space" -split-input-file | FileCheck %s
+
+// TODO: move to tensor dialect tests
+func.func @from_elements(%fill: f32, %f: f32, %idx: index) -> tensor<3xf32, 1> {
+  %t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
+  %i = tensor.insert %f into %t[%idx] : tensor<3xf32, 1>
+  return %i : tensor<3xf32, 1>
+}
+
+// CHECK-LABEL: @from_elements
+//  CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: f32, %[[arg2:.+]]: index) -> tensor<3xf32, 1 : i64>
+//       CHECK:     %[[alloc:.+]] = memref.alloc() {{.*}} : memref<3xf32, 1>
+//       CHECK-DAG:     %[[c0:.+]] = arith.constant 0 : index
+//       CHECK-DAG:     %[[c1:.+]] = arith.constant 1 : index
+//       CHECK-DAG:     %[[c2:.+]] = arith.constant 2 : index
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c0]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c1]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg0]], %[[alloc]][%[[c2]]] : memref<3xf32, 1>
+//       CHECK:     memref.store %[[arg1]], %[[alloc]][%[[arg2]]] : mem...
[truncated]
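
For readers without an MLIR build at hand, the comparison performed by the `tensorTypesMatchUpToEncoding` helper added to `BufferizationOps.cpp` in the diff above can be modeled in standalone C++. This is only a sketch: `ToyTensorType` and `matchUpToEncoding` are hypothetical stand-ins rather than MLIR APIs, and plain strings stand in for element types and encoding attributes.

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor type; not an MLIR class.
struct ToyTensorType {
  std::string elementType;                   // e.g. "f32"
  std::optional<std::vector<int64_t>> shape; // nullopt models an unranked tensor
  std::string encoding;                      // empty models "no encoding"
};

// Mirrors the logic of the helper in the diff: element types must agree,
// shapes must agree when both types are ranked, and the encoding is ignored.
bool matchUpToEncoding(const ToyTensorType &lhs, const ToyTensorType &rhs) {
  if (lhs.elementType != rhs.elementType)
    return false;
  if (lhs.shape && rhs.shape)
    return *lhs.shape == *rhs.shape;
  return true;
}
```

Under this model, two tensors that differ only in encoding compare as matching, and a ranked tensor matches an unranked one with the same element type, which follows from the final `return true` in the original helper.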

@christopherbate
Contributor Author

@matthias-springer Can you review? The vast majority of the changes here just update tests for the new assembly formats of bufferization.to_tensor and bufferization.to_memref.
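
To illustrate the kind of mechanical test update involved, here is a rough C++ sketch of a line rewriter for the simplest case. `upgradeLine` is a hypothetical helper, not part of MLIR or of this PR. It assumes the memref has no layout and no memory space and that the tensor carries no encoding, so the tensor type can be spelled by swapping `memref` for `tensor`; every other line is left for manual editing.

```cpp
#include <optional>
#include <regex>
#include <string>

// Hypothetical helper, not part of MLIR: rewrites one line from the old
// single-type assembly format to the new "operand type -> result type" form.
// Only plain memrefs such as memref<10xf32> are handled; lines with a layout,
// a memory space, or an already-upgraded type return std::nullopt, because
// the tensor type (and its encoding) cannot be guessed from the text alone.
std::optional<std::string> upgradeLine(const std::string &line) {
  static const std::regex pattern(
      R"(^(.*bufferization\.to_(memref|tensor) [^:]*: )memref<([^,<>]*)>\s*$)");
  std::smatch m;
  if (!std::regex_match(line, m, pattern))
    return std::nullopt;
  const std::string body = m[3].str();
  const std::string memref = "memref<" + body + ">";
  const std::string tensor = "tensor<" + body + ">";
  // to_memref takes a tensor and yields a memref; to_tensor is the reverse.
  if (m[2].str() == "memref")
    return m[1].str() + tensor + " -> " + memref;
  return m[1].str() + memref + " -> " + tensor;
}
```

Returning `std::nullopt` for anything beyond the plain case is deliberate: once encodings or memory spaces are involved, the correct tensor type is exactly the information that text-level inference cannot recover.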

Member

@matthias-springer matthias-springer left a comment

Can you add extra unit tests for scf.for (with a tensor iter_arg that has an encoding), scf.if, and scf.execute_region (with a tensor result)? I think these bufferization patterns must be updated because they create new to_tensor ops.

// -----

func.func @alloc_tesor_copy_from_non_default_space(%arg0: tensor<128xf32, 1>) -> tensor<128xf32, 2> {
%0 = bufferization.alloc_tensor() copy(%arg0) {memory_space = 2 : i64} : tensor<128xf32, 1>
Member

What is the meaning of this op? The two memory spaces here are inconsistent.

Contributor Author

@christopherbate christopherbate May 20, 2024

Since having a copy generates a memref.copy, the intent is to copy data between memory spaces. However, bufferization.alloc_tensor prints only one type for both the copy operand and the result. Therefore, we will need to update the assembly format of bufferization.alloc_tensor as well. That will enable dropping the tensor.cast from this test.

@matthias-springer
Member

Can you add extra unit tests for scf.for (with a tensor iter_arg that has an encoding), scf.if, and scf.execute_region (with a tensor result)? I think these bufferization patterns must be updated because they create new to_tensor ops.

Ideally, to keep this PR from growing even larger, if these patterns do need an update, you could update them in a first, separate PR. I think not much has to change: you just have to take the tensor types from the input IR instead of inferring them in the op builder.
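
The difference between inferring the tensor type and taking it from the input IR can be sketched with a toy model in standalone C++. All names here (`ToyMemRefType`, `ToyRankedTensorType`, `inferTensorType`, `toTensorWithExplicitType`) are hypothetical and do not exist in MLIR. The assumption that inference carries over only shape and element type is based on the PR description, which says type inference cannot recover the tensor encoding.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-ins for memref and tensor types; not MLIR classes.
struct ToyMemRefType {
  std::vector<int64_t> shape;
  std::string elementType;
  std::string memorySpace; // empty models the default memory space
};

struct ToyRankedTensorType {
  std::vector<int64_t> shape;
  std::string elementType;
  std::string encoding; // empty models "no encoding"
};

// Models result-type inference in an op builder: only shape and element type
// carry over, so an encoding on the original tensor is lost (assumption based
// on the PR description).
ToyRankedTensorType inferTensorType(const ToyMemRefType &buffer) {
  return ToyRankedTensorType{buffer.shape, buffer.elementType, ""};
}

// Models the suggested fix: the caller passes the tensor type taken from the
// input IR, so the encoding survives the round trip through a buffer.
ToyRankedTensorType
toTensorWithExplicitType(const ToyRankedTensorType &typeFromInputIR,
                         const ToyMemRefType & /*buffer*/) {
  return typeFromInputIR;
}
```

In the actual change, the corresponding step is visible in `replaceOpWithBufferizedValues`, where `opResult.getType()` is now passed when creating the `ToTensorOp`.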

@aartbik
Contributor

aartbik commented Jun 10, 2024

lgtm for all sparse test signature changes

@christopherbate christopherbate force-pushed the fix-bufferization-tensor-type-encoding branch 2 times, most recently from f90c2f1 to 543e5a0 Compare August 31, 2024 04:29
@christopherbate christopherbate force-pushed the fix-bufferization-tensor-type-encoding branch from 543e5a0 to 7780f90 Compare September 8, 2024 19:32
@srcarroll
Contributor

Is this PR dead? I'm wondering why it has been sitting here approved for months.

@christopherbate
Contributor Author

It's not dead; I have just been pressed for time. Let me rebase and address Matthias' last comments this week.

@CoTinker
Contributor

Hi, this PR is useful to me. Do you have time to rebase it? Thanks.

@christopherbate christopherbate force-pushed the fix-bufferization-tensor-type-encoding branch from ce6c22d to c627d06 Compare November 26, 2024 01:59
@christopherbate christopherbate force-pushed the fix-bufferization-tensor-type-encoding branch from 38cc33f to 4d81769 Compare November 26, 2024 04:00
@christopherbate christopherbate force-pushed the fix-bufferization-tensor-type-encoding branch from 92d486c to bbe57de Compare November 26, 2024 16:27
@christopherbate christopherbate merged commit ced2fc7 into llvm:main Nov 26, 2024
8 checks passed
@christopherbate christopherbate deleted the fix-bufferization-tensor-type-encoding branch November 26, 2024 16:46