Skip to content

[mlir][bufferization] Canonicalize to_memref(to_tensor(x)) to a CopyO… #126692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 11, 2025

Conversation

amrami
Copy link
Contributor

@amrami amrami commented Feb 11, 2025

…p if memory spaces differ

@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Feb 11, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 11, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Author: Maya Amrami (amrami)

Changes

…p if memory spaces differ


Full diff: https://github.com/llvm/llvm-project/pull/126692.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+1-3)
  • (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+12-10)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 6be55a1d282240b..4fce9be390bd6c5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -28,11 +28,9 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
     const BufferizationOptions &options) {
   auto srcType = llvm::cast<MemRefType>(value.getType());
 
-  // Element type, rank and memory space must match.
+  // Element type and rank must match.
   if (srcType.getElementType() != destType.getElementType())
     return failure();
-  if (srcType.getMemorySpace() != destType.getMemorySpace())
-    return failure();
   if (srcType.getRank() != destType.getRank())
     return failure();
 
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index 3ebc1e4fa8dea34..833d83be89b5aab 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -27,18 +27,20 @@ func.func @buffer_cast_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
 
 // -----
 
-// If the memrefs are not the same type, don't fold them.
-// If the memrefs are not cast-compatible (e.g. different address space), don't
-// canonicalize them either.
-// CHECK-LABEL: func @no_fold_buffer_cast_of_tensor_load(
+// If the memrefs are not cast-compatible (e.g. different address space),
+// canonicalize them to copy.
+// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_different_address_space(
 //  CHECK-SAME:   %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>)
 //  CHECK-SAME:     -> memref<?xf32, 7> {
-//       CHECK: %[[TENSOR:.*]] = bufferization.to_tensor
-//  CHECK-SAME:   %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2> to tensor<?xf32, 7 : i64>
-//       CHECK: %[[MEMREF_ADDRSPACE7:.*]] = bufferization.to_memref
-//  CHECK-SAME:   %[[TENSOR]] : tensor<?xf32, 7 : i64> to memref<?xf32, 7>
-//       CHECK: return %[[MEMREF_ADDRSPACE7]]
-func.func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>)
+//  CHECK-NOT: bufferization.to_tensor
+//  CHECK-NOT: bufferization.to_memref
+//      CHECK: %[[C0:.*]] = arith.constant 0 : index
+//      CHECK: %[[DIM:.*]] = memref.dim %[[MEMREF_ADDRSPACE2]], %[[C0]] : memref<?xf32, 2>
+//      CHECK: %[[MEMREF_ADDRSPACE7:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, 7>
+//      CHECK: memref.copy %[[MEMREF_ADDRSPACE2]], %[[MEMREF_ADDRSPACE7]]
+// CHECK-SAME:   memref<?xf32, 2> to memref<?xf32, 7>
+//      CHECK: return %[[MEMREF_ADDRSPACE7]]
+func.func @canonicalize_buffer_cast_of_tensor_load_different_address_space(%arg0: memref<?xf32, 2>)
     -> memref<?xf32, 7> {
   %0 = bufferization.to_tensor %arg0 : memref<?xf32, 2> to tensor<?xf32, 7>
   %1 = bufferization.to_memref %0 : tensor<?xf32, 7> to memref<?xf32, 7>

@amrami amrami force-pushed the mamrami/bufferization_add_copy branch 3 times, most recently from 5609cf2 to 82491fa Compare February 11, 2025 07:39
@amrami amrami force-pushed the mamrami/bufferization_add_copy branch 2 times, most recently from 6c061dd to 160a658 Compare February 11, 2025 09:58
@amirBish amirBish merged commit 8cf4c55 into llvm:main Feb 11, 2025
11 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants