@@ -28,17 +28,20 @@ func.func @buffer_cast_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
 // -----
 
 // If the memrefs are not the same type, don't fold them.
-// If the memrefs are not cast-compatible (e.g. different address space), don't
-// canonicalize them either.
-// CHECK-LABEL: func @no_fold_buffer_cast_of_tensor_load(
+// If the memrefs are not cast-compatible but one can be copied into the other
+// (e.g. different address space), canonicalize them to alloc + copy.
+// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_different_address_space(
 // CHECK-SAME: %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>)
 // CHECK-SAME: -> memref<?xf32, 7> {
-// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor
-// CHECK-SAME: %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2> to tensor<?xf32, 7 : i64>
-// CHECK: %[[MEMREF_ADDRSPACE7:.*]] = bufferization.to_memref
-// CHECK-SAME: %[[TENSOR]] : tensor<?xf32, 7 : i64> to memref<?xf32, 7>
-// CHECK: return %[[MEMREF_ADDRSPACE7]]
-func.func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>)
+// CHECK-NOT: bufferization.to_tensor
+// CHECK-NOT: bufferization.to_memref
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[MEMREF_ADDRSPACE2]], %[[C0]] : memref<?xf32, 2>
+// CHECK: %[[MEMREF_ADDRSPACE7:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, 7>
+// CHECK: memref.copy %[[MEMREF_ADDRSPACE2]], %[[MEMREF_ADDRSPACE7]]
+// CHECK-SAME: memref<?xf32, 2> to memref<?xf32, 7>
+// CHECK: return %[[MEMREF_ADDRSPACE7]]
+func.func @canonicalize_buffer_cast_of_tensor_load_different_address_space(%arg0: memref<?xf32, 2>)
     -> memref<?xf32, 7> {
   %0 = bufferization.to_tensor %arg0 : memref<?xf32, 2> to tensor<?xf32, 7>
   %1 = bufferization.to_memref %0 : tensor<?xf32, 7> to memref<?xf32, 7>
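For reference, here is a sketch of the canonicalized IR that the new CHECK lines describe, reconstructed from the directives above. The SSA value names are illustrative (the test only captures them via FileCheck patterns), and the usual `mlir-opt -canonicalize` RUN line at the top of the test file is assumed. Because the two memref types live in different address spaces, `to_memref(to_tensor(x))` cannot fold into a `memref.cast`; instead the pattern materializes a fresh buffer in the target address space and copies into it:

```mlir
func.func @canonicalize_buffer_cast_of_tensor_load_different_address_space(%arg0: memref<?xf32, 2>)
    -> memref<?xf32, 7> {
  %c0 = arith.constant 0 : index
  // Dynamic extent of the source buffer, needed to size the allocation.
  %dim = memref.dim %arg0, %c0 : memref<?xf32, 2>
  // Fresh buffer in the target address space (7).
  %alloc = memref.alloc(%dim) : memref<?xf32, 7>
  // The cross-address-space copy replaces the to_tensor/to_memref pair.
  memref.copy %arg0, %alloc : memref<?xf32, 2> to memref<?xf32, 7>
  return %alloc : memref<?xf32, 7>
}
```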