Skip to content

Commit bf77747

Browse files
Add additional tests for scf.for and scf.execute_region
1 parent 645ed9d commit bf77747

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ struct ExecuteRegionOpInterface
203203
for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
204204
if (isa<TensorType>(it.value())) {
205205
newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
206-
executeRegionOp.getLoc(), newOp->getResult(it.index())));
206+
executeRegionOp.getLoc(), it.value(),
207+
newOp->getResult(it.index())));
207208
} else {
208209
newResults.push_back(newOp->getResult(it.index()));
209210
}

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-encodings.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,22 @@ func.func @materialize_in_destination(%arg0: tensor<128xf32, 1>) -> tensor<128xf
131131
// CHECK: memref.copy %[[v0]], %[[alloc]] : memref<128xf32, strided<[?], offset: ?>, 1> to memref<128xf32, 2>
132132
// CHECK: %[[v1:.+]] = bufferization.to_tensor %[[alloc]] : memref<128xf32, 2> -> tensor<128xf32, 2 : i64>
133133
// CHECK: return %[[v1]] : tensor<128xf32, 2 : i64>
134+
135+
// -----
136+
137+
func.func @scf_for_iter_arg(%arg0: tensor<128xf32, 1>, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> tensor<128xf32, 1> {
138+
%0 = scf.for %i = %arg1 to %arg2 step %arg3 iter_args(%iter = %arg0) -> tensor<128xf32, 1> {
139+
%0 = tensor.insert %arg4 into %iter[%i] : tensor<128xf32, 1>
140+
scf.yield %0 : tensor<128xf32, 1>
141+
}
142+
return %0 : tensor<128xf32, 1>
143+
}
144+
145+
// -----
146+
147+
func.func @scf_execute_region(%arg0: tensor<128xf32, 1>) -> tensor<128xf32, 1> {
148+
%0 = scf.execute_region -> tensor<128xf32, 1> {
149+
scf.yield %arg0 : tensor<128xf32, 1>
150+
}
151+
return %0 : tensor<128xf32, 1>
152+
}

0 commit comments

Comments
 (0)