// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space allow-return-allocs-from-loops allow-unknown-ops" -allow-unregistered-dialect -split-input-file | FileCheck %s

// Here and below, the unknown op 'some.use' forces 'bufferization.to_tensor' operations to remain in the body,
// allowing us to check that the encoding on the '%iter' tensor is correctly preserved.

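// Check that the memory space encoding (1) on the iter_arg tensor is propagated to the
// bufferized loop: the alloc, copy, and loop-carried memref should all live in space 1.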
func.func @scf_for_iter_arg(%arg0: tensor<128xf32, 1>, %arg1: index, %arg2: index, %arg3: index) -> tensor<128xf32, 1> {
  %0 = scf.for %i = %arg1 to %arg2 step %arg3 iter_args(%iter = %arg0) -> tensor<128xf32, 1> {
    %0 = "some.use"(%iter) : (tensor<128xf32, 1>) -> tensor<128xf32, 1>
    scf.yield %0 : tensor<128xf32, 1>
  }
  return %0 : tensor<128xf32, 1>
}

// CHECK-LABEL: func.func @scf_for_iter_arg
// CHECK-SAME: (%[[arg0:.+]]: tensor<128xf32, 1 : i64>, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
// CHECK: %[[v0:.+]] = bufferization.to_memref %[[arg0]] : tensor<128xf32, 1 : i64> -> memref<128xf32, strided<[?], offset: ?>, 1>
// CHECK: %[[alloc:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128xf32, 1>
// CHECK: memref.copy %[[v0]], %[[alloc]] : memref<128xf32, strided<[?], offset: ?>, 1> to memref<128xf32, 1>
// CHECK: %[[cast:.+]] = memref.cast %[[alloc]] : memref<128xf32, 1> to memref<128xf32, strided<[?], offset: ?>, 1>
// CHECK: %[[v1:.+]] = scf.for %{{.+}} = %[[arg1]] to %[[arg2]] step %[[arg3]] iter_args(%[[arg6:.+]] = %[[cast]]) -> (memref<128xf32, strided<[?], offset: ?>, 1>)
// CHECK-NEXT: %[[v3:.+]] = bufferization.to_tensor %[[arg6]] : memref<128xf32, strided<[?], offset: ?>, 1> -> tensor<128xf32, 1 : i64>
// CHECK-NEXT: %[[v4:.+]] = "some.use"(%[[v3]]) : (tensor<128xf32, 1 : i64>) -> tensor<128xf32, 1 : i64>
// CHECK-NEXT: %[[v5:.+]] = bufferization.to_memref %[[v4]] : tensor<128xf32, 1 : i64> -> memref<128xf32, strided<[?], offset: ?>, 1>
// CHECK-NEXT: scf.yield %[[v5]] : memref<128xf32, strided<[?], offset: ?>, 1>
// CHECK: %[[v2:.+]] = bufferization.to_tensor %[[v1]] : memref<128xf32, strided<[?], offset: ?>, 1> -> tensor<128xf32, 1 : i64>
// CHECK: return %[[v2]] : tensor<128xf32, 1 : i64>

// -----

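// Check that the encoding on the shared_outs tensor survives bufferization of scf.forall:
// the to_tensor/to_memref ops inside the body and the returned tensor keep memory space 1.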
func.func @scf_forall(
    %idx: index,
    %idx2: index,
    %arg1: tensor<?xf32, 1>,
    %arg2: tensor<?xf32, 1>) -> (tensor<?xf32, 1>) {
  %cst = arith.constant 4.200000e+01 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %2 = scf.forall (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor<?xf32, 1>) {
    %8 = "some.use"(%o) : (tensor<?xf32, 1>) -> tensor<?xf32, 1>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %8 into %o[5] [%idx] [%c1] :
        tensor<?xf32, 1> into tensor<?xf32, 1>
    }
  }
  return %2 : tensor<?xf32, 1>
}

// CHECK-LABEL: func.func @scf_forall
// CHECK: scf.forall
// CHECK: %[[v2:.+]] = bufferization.to_tensor %{{.+}} : memref<?xf32, 1> -> tensor<?xf32, 1 : i64>
// CHECK: %[[v3:.+]] = "some.use"(%[[v2]]) : (tensor<?xf32, 1 : i64>) -> tensor<?xf32, 1 : i64>
// CHECK: bufferization.to_memref %[[v3]] : tensor<?xf32, 1 : i64> -> memref<?xf32, strided<[?], offset: ?>, 1>
// CHECK: %[[v1:.+]] = bufferization.to_tensor %{{.+}} : memref<?xf32, 1> -> tensor<?xf32, 1 : i64>
// CHECK: return %[[v1]] : tensor<?xf32, 1 : i64>

// -----

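// Check that scf.execute_region yields a memref in memory space 1 and that the encoding
// is restored on the to_tensor result outside the region.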
func.func @scf_execute_region(%arg0: tensor<128xf32, 1>) -> tensor<128xf32, 1> {
  %0 = scf.execute_region -> tensor<128xf32, 1> {
    scf.yield %arg0 : tensor<128xf32, 1>
  }
  %1 = "some.use"(%0) : (tensor<128xf32, 1>) -> tensor<128xf32, 1>
  return %1 : tensor<128xf32, 1>
}

// CHECK-LABEL: func.func @scf_execute_region
// CHECK-SAME: (%[[arg0:.+]]: tensor<128xf32, 1 : i64>)
// CHECK: %[[v0:.+]] = bufferization.to_memref %[[arg0]] : tensor<128xf32, 1 : i64> -> memref<128xf32, strided<[?], offset: ?>, 1>
// CHECK: %[[v1:.+]] = scf.execute_region -> memref<128xf32, strided<[?], offset: ?>, 1>
// CHECK: scf.yield %[[v0]] : memref<128xf32, strided<[?], offset: ?>, 1>
// CHECK: %[[v2:.+]] = bufferization.to_tensor %[[v1]] : memref<128xf32, strided<[?], offset: ?>, 1> -> tensor<128xf32, 1 : i64>
// CHECK: %[[v3:.+]] = "some.use"(%[[v2]]) : (tensor<128xf32, 1 : i64>) -> tensor<128xf32, 1 : i64>
// CHECK: return %[[v3]] : tensor<128xf32, 1 : i64>