Skip to content

Commit 8cf4c55

Browse files
authored
[mlir][bufferization] Canonicalize to_memref(to_tensor(x)) to a CopyO… (llvm#126692)
…p if memory spaces differ
1 parent 19556ec commit 8cf4c55

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,9 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
2828
const BufferizationOptions &options) {
2929
auto srcType = llvm::cast<MemRefType>(value.getType());
3030

31-
// Element type, rank and memory space must match.
31+
// Element type and rank must match.
3232
if (srcType.getElementType() != destType.getElementType())
3333
return failure();
34-
if (srcType.getMemorySpace() != destType.getMemorySpace())
35-
return failure();
3634
if (srcType.getRank() != destType.getRank())
3735
return failure();
3836

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,20 @@ func.func @buffer_cast_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
2828
// -----
2929

3030
// If the memrefs are not the same type, don't fold them.
31-
// If the memrefs are not cast-compatible (e.g. different address space), don't
32-
// canonicalize them either.
33-
// CHECK-LABEL: func @no_fold_buffer_cast_of_tensor_load(
31+
// If the memrefs are not cast-compatible but one can be copied into the other
32+
// (e.g. different address space), canonicalize them to add + copy.
33+
// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_different_address_space(
3434
// CHECK-SAME: %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>)
3535
// CHECK-SAME: -> memref<?xf32, 7> {
36-
// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor
37-
// CHECK-SAME: %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2> to tensor<?xf32, 7 : i64>
38-
// CHECK: %[[MEMREF_ADDRSPACE7:.*]] = bufferization.to_memref
39-
// CHECK-SAME: %[[TENSOR]] : tensor<?xf32, 7 : i64> to memref<?xf32, 7>
40-
// CHECK: return %[[MEMREF_ADDRSPACE7]]
41-
func.func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>)
36+
// CHECK-NOT: bufferization.to_tensor
37+
// CHECK-NOT: bufferization.to_memref
38+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
39+
// CHECK: %[[DIM:.*]] = memref.dim %[[MEMREF_ADDRSPACE2]], %[[C0]] : memref<?xf32, 2>
40+
// CHECK: %[[MEMREF_ADDRSPACE7:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, 7>
41+
// CHECK: memref.copy %[[MEMREF_ADDRSPACE2]], %[[MEMREF_ADDRSPACE7]]
42+
// CHECK-SAME: memref<?xf32, 2> to memref<?xf32, 7>
43+
// CHECK: return %[[MEMREF_ADDRSPACE7]]
44+
func.func @canonicalize_buffer_cast_of_tensor_load_different_address_space(%arg0: memref<?xf32, 2>)
4245
-> memref<?xf32, 7> {
4346
%0 = bufferization.to_tensor %arg0 : memref<?xf32, 2> to tensor<?xf32, 7>
4447
%1 = bufferization.to_memref %0 : tensor<?xf32, 7> to memref<?xf32, 7>

0 commit comments

Comments
 (0)