Skip to content

Commit 22eeae8

Browse files
Add more/better tests for ToTensorOp creation in SCF op bufferizations
1 parent c6c0504 commit 22eeae8

File tree

4 files changed

+108
-51
lines changed

4 files changed

+108
-51
lines changed

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -486,15 +486,17 @@ getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
486486
/// ToTensorOps, so that the block body can be moved over to the new op.
487487
static SmallVector<Value>
488488
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
489+
Block::BlockArgListType oldBbArgs,
489490
const DenseSet<int64_t> &tensorIndices) {
490491
SmallVector<Value> result;
491492
for (const auto &it : llvm::enumerate(bbArgs)) {
492493
size_t idx = it.index();
493494
Value val = it.value();
494495
if (tensorIndices.contains(idx)) {
495-
result.push_back(
496-
rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
497-
.getResult());
496+
result.push_back(rewriter
497+
.create<bufferization::ToTensorOp>(
498+
val.getLoc(), oldBbArgs[idx].getType(), val)
499+
.getResult());
498500
} else {
499501
result.push_back(val);
500502
}
@@ -764,7 +766,8 @@ struct ForOpInterface
764766
// iter_args of the new loop in ToTensorOps.
765767
rewriter.setInsertionPointToStart(loopBody);
766768
SmallVector<Value> iterArgs =
767-
getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
769+
getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
770+
forOp.getRegionIterArgs(), indices);
768771
iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
769772

770773
// Move loop body to new loop.
@@ -1001,16 +1004,18 @@ struct WhileOpInterface
10011004
// The old block uses tensors, so wrap the (memref) bbArgs of the new block
10021005
// in ToTensorOps.
10031006
rewriter.setInsertionPointToStart(newBeforeBody);
1004-
SmallVector<Value> newBeforeArgs = getBbArgReplacements(
1005-
rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
1007+
SmallVector<Value> newBeforeArgs =
1008+
getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
1009+
whileOp.getBeforeArguments(), indicesBefore);
10061010
rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);
10071011

10081012
// Set up new iter_args and move the loop body block to the new op.
10091013
// The old block uses tensors, so wrap the (memref) bbArgs of the new block
10101014
// in ToTensorOps.
10111015
rewriter.setInsertionPointToStart(newAfterBody);
1012-
SmallVector<Value> newAfterArgs = getBbArgReplacements(
1013-
rewriter, newWhileOp.getAfterArguments(), indicesAfter);
1016+
SmallVector<Value> newAfterArgs =
1017+
getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
1018+
whileOp.getAfterArguments(), indicesAfter);
10141019
rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);
10151020

10161021
// Replace loop results.
@@ -1256,8 +1261,8 @@ struct ForallOpInterface
12561261
forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
12571262
BlockArgument bbArg = std::get<0>(it);
12581263
Value buffer = std::get<1>(it);
1259-
Value bufferAsTensor =
1260-
rewriter.create<ToTensorOp>(forallOp.getLoc(), buffer);
1264+
Value bufferAsTensor = rewriter.create<ToTensorOp>(
1265+
forallOp.getLoc(), bbArg.getType(), buffer);
12611266
bbArg.replaceAllUsesWith(bufferAsTensor);
12621267
}
12631268

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir

Lines changed: 0 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -1,27 +1,5 @@
11
// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space" -split-input-file | FileCheck %s
22

3-
// TODO: move to tensor dialect tests
4-
func.func @from_elements(%fill: f32, %f: f32, %idx: index) -> tensor<3xf32, 1> {
5-
%t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
6-
%i = tensor.insert %f into %t[%idx] : tensor<3xf32, 1>
7-
return %i : tensor<3xf32, 1>
8-
}
9-
10-
// CHECK-LABEL: @from_elements
11-
// CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: f32, %[[arg2:.+]]: index) -> tensor<3xf32, 1 : i64>
12-
// CHECK: %[[alloc:.+]] = memref.alloc() {{.*}} : memref<3xf32, 1>
13-
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
14-
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
15-
// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
16-
// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c0]]] : memref<3xf32, 1>
17-
// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c1]]] : memref<3xf32, 1>
18-
// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c2]]] : memref<3xf32, 1>
19-
// CHECK: memref.store %[[arg1]], %[[alloc]][%[[arg2]]] : memref<3xf32, 1>
20-
// CHECK: %[[v0:.+]] = bufferization.to_tensor %[[alloc]] : memref<3xf32, 1> -> tensor<3xf32, 1 : i64>
21-
// CHECK: return %[[v0]] : tensor<3xf32, 1 : i64>
22-
23-
// -----
24-
253
func.func @alloc_tesor_with_space_no_encoding() -> tensor<128xf32> {
264
%0 = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<128xf32>
275
return %0 : tensor<128xf32>
@@ -131,22 +109,3 @@ func.func @materialize_in_destination(%arg0: tensor<128xf32, 1>) -> tensor<128xf
131109
// CHECK: memref.copy %[[v0]], %[[alloc]] : memref<128xf32, strided<[?], offset: ?>, 1> to memref<128xf32, 2>
132110
// CHECK: %[[v1:.+]] = bufferization.to_tensor %[[alloc]] : memref<128xf32, 2> -> tensor<128xf32, 2 : i64>
133111
// CHECK: return %[[v1]] : tensor<128xf32, 2 : i64>
134-
135-
// -----
136-
137-
func.func @scf_for_iter_arg(%arg0: tensor<128xf32, 1>, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> tensor<128xf32, 1> {
138-
%0 = scf.for %i = %arg1 to %arg2 step %arg3 iter_args(%iter = %arg0) -> tensor<128xf32, 1> {
139-
%0 = tensor.insert %arg4 into %iter[%i] : tensor<128xf32, 1>
140-
scf.yield %0 : tensor<128xf32, 1>
141-
}
142-
return %0 : tensor<128xf32, 1>
143-
}
144-
145-
// -----
146-
147-
func.func @scf_execute_region(%arg0: tensor<128xf32, 1>) -> tensor<128xf32, 1> {
148-
%0 = scf.execute_region -> tensor<128xf32, 1> {
149-
scf.yield %arg0 : tensor<128xf32, 1>
150-
}
151-
return %0 : tensor<128xf32, 1>
152-
}
Lines changed: 73 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,73 @@
1+
// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space allow-return-allocs-from-loops allow-unknown-ops" -allow-unregistered-dialect -split-input-file | FileCheck %s
2+
3+
// Here and below, unknown op 'some.use' will force 'bufferization.to_tensor' operations to remain in the body,
4+
// allowing us to check that the encoding on the '%iter' tensor is correctly preserved.
5+
6+
func.func @scf_for_iter_arg(%arg0: tensor<128xf32, 1>, %arg1: index, %arg2: index, %arg3: index) -> tensor<128xf32, 1> {
7+
%0 = scf.for %i = %arg1 to %arg2 step %arg3 iter_args(%iter = %arg0) -> tensor<128xf32, 1> {
8+
%0 = "some.use"(%iter) : (tensor<128xf32, 1>) -> tensor<128xf32, 1>
9+
scf.yield %0 : tensor<128xf32, 1>
10+
}
11+
return %0 : tensor<128xf32, 1>
12+
}
13+
14+
// CHECK-LABEL: func.func @scf_for_iter_arg
15+
// CHECK-SAME: (%[[arg0:.+]]: tensor<128xf32, 1 : i64>, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
16+
// CHECK: %[[v0:.+]] = bufferization.to_memref %[[arg0]] : tensor<128xf32, 1 : i64> -> memref<128xf32, strided<[?], offset: ?>, 1>
17+
// CHECK: %[[alloc:.+]] = memref.alloc() {alignment = 64 : i64} : memref<128xf32, 1>
18+
// CHECK: memref.copy %[[v0]], %[[alloc]] : memref<128xf32, strided<[?], offset: ?>, 1> to memref<128xf32, 1>
19+
// CHECK: %[[cast:.+]] = memref.cast %[[alloc]] : memref<128xf32, 1> to memref<128xf32, strided<[?], offset: ?>, 1>
20+
// CHECK: %[[v1:.+]] = scf.for %{{.+}} = %[[arg1]] to %[[arg2]] step %[[arg3]] iter_args(%[[arg6:.+]] = %[[cast]]) -> (memref<128xf32, strided<[?], offset: ?>, 1>)
21+
// CHECK-NEXT: %[[v3:.+]] = bufferization.to_tensor %[[arg6]] : memref<128xf32, strided<[?], offset: ?>, 1> -> tensor<128xf32, 1 : i64>
22+
// CHECK-NEXT: %[[v4:.+]] = "some.use"(%[[v3]]) : (tensor<128xf32, 1 : i64>) -> tensor<128xf32, 1 : i64>
23+
// CHECK-NEXT: %[[v5:.+]] = bufferization.to_memref %[[v4]] : tensor<128xf32, 1 : i64> -> memref<128xf32, strided<[?], offset: ?>, 1>
24+
// CHECK-NEXT: scf.yield %[[v5]] : memref<128xf32, strided<[?], offset: ?>, 1>
25+
// CHECK: %[[v2:.+]] = bufferization.to_tensor %[[v1]] : memref<128xf32, strided<[?], offset: ?>, 1> -> tensor<128xf32, 1 : i64>
26+
// CHECK: return %[[v2]] : tensor<128xf32, 1 : i64>
27+
28+
// -----
29+
30+
func.func @scf_forall(
31+
%idx: index,
32+
%idx2: index,
33+
%arg1: tensor<?xf32, 1>,
34+
%arg2: tensor<?xf32, 1>) -> (tensor<?xf32, 1>) {
35+
%cst = arith.constant 4.200000e+01 : f32
36+
%c0 = arith.constant 0 : index
37+
%c1 = arith.constant 1 : index
38+
%2 = scf.forall (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor<?xf32, 1>) {
39+
%8 = "some.use"(%o) : (tensor<?xf32, 1>) -> tensor<?xf32, 1>
40+
scf.forall.in_parallel {
41+
tensor.parallel_insert_slice %8 into %o[5] [%idx] [%c1] :
42+
tensor<?xf32, 1> into tensor<?xf32, 1>
43+
}
44+
}
45+
return %2 : tensor<?xf32, 1>
46+
}
47+
48+
// CHECK-LABEL: func.func @scf_forall
49+
// CHECK: scf.forall
50+
// CHECK: %[[v2:.+]] = bufferization.to_tensor %{{.+}} : memref<?xf32, 1> -> tensor<?xf32, 1 : i64>
51+
// CHECK: %[[v3:.+]] = "some.use"(%[[v2]]) : (tensor<?xf32, 1 : i64>) -> tensor<?xf32, 1 : i64>
52+
// CHECK: bufferization.to_memref %[[v3]] : tensor<?xf32, 1 : i64> -> memref<?xf32, strided<[?], offset: ?>, 1>
53+
// CHECK: %[[v1:.+]] = bufferization.to_tensor %{{.+}} : memref<?xf32, 1> -> tensor<?xf32, 1 : i64>
54+
// CHECK: return %[[v1]] : tensor<?xf32, 1 : i64>
55+
56+
// -----
57+
58+
func.func @scf_execute_region(%arg0: tensor<128xf32, 1>) -> tensor<128xf32, 1> {
59+
%0 = scf.execute_region -> tensor<128xf32, 1> {
60+
scf.yield %arg0 : tensor<128xf32, 1>
61+
}
62+
%1 = "some.use"(%0) : (tensor<128xf32, 1>) -> tensor<128xf32, 1>
63+
return %1 : tensor<128xf32, 1>
64+
}
65+
66+
// CHECK-LABEL: func.func @scf_execute_region
67+
// CHECK-SAME: (%[[arg0:.+]]: tensor<128xf32, 1 : i64>)
68+
// CHECK: %[[v0:.+]] = bufferization.to_memref %[[arg0]] : tensor<128xf32, 1 : i64> -> memref<128xf32, strided<[?], offset: ?>, 1>
69+
// CHECK: %[[v1:.+]] = scf.execute_region -> memref<128xf32, strided<[?], offset: ?>, 1>
70+
// CHECK: scf.yield %[[v0]] : memref<128xf32, strided<[?], offset: ?>, 1>
71+
// CHECK: %[[v2:.+]] = bufferization.to_tensor %[[v1]] : memref<128xf32, strided<[?], offset: ?>, 1> -> tensor<128xf32, 1 : i64>
72+
// CHECK: %[[v3:.+]] = "some.use"(%[[v2]]) : (tensor<128xf32, 1 : i64>) -> tensor<128xf32, 1 : i64>
73+
// CHECK: return %[[v3]] : tensor<128xf32, 1 : i64>
Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,20 @@
1+
// RUN: mlir-opt %s -one-shot-bufferize="use-encoding-for-memory-space" -split-input-file | FileCheck %s
2+
3+
func.func @from_elements(%fill: f32, %f: f32, %idx: index) -> tensor<3xf32, 1> {
4+
%t = tensor.from_elements %fill, %fill, %fill : tensor<3xf32, 1>
5+
%i = tensor.insert %f into %t[%idx] : tensor<3xf32, 1>
6+
return %i : tensor<3xf32, 1>
7+
}
8+
9+
// CHECK-LABEL: @from_elements
10+
// CHECK-SAME: (%[[arg0:.+]]: f32, %[[arg1:.+]]: f32, %[[arg2:.+]]: index) -> tensor<3xf32, 1 : i64>
11+
// CHECK: %[[alloc:.+]] = memref.alloc() {{.*}} : memref<3xf32, 1>
12+
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
13+
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
14+
// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
15+
// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c0]]] : memref<3xf32, 1>
16+
// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c1]]] : memref<3xf32, 1>
17+
// CHECK: memref.store %[[arg0]], %[[alloc]][%[[c2]]] : memref<3xf32, 1>
18+
// CHECK: memref.store %[[arg1]], %[[alloc]][%[[arg2]]] : memref<3xf32, 1>
19+
// CHECK: %[[v0:.+]] = bufferization.to_tensor %[[alloc]] : memref<3xf32, 1> -> tensor<3xf32, 1 : i64>
20+
// CHECK: return %[[v0]] : tensor<3xf32, 1 : i64>

0 commit comments

Comments
 (0)