Skip to content

Commit 7dd7078

Browse files
[mlir][linalg][bufferize] Handle scf::ForOp correctly in bufferizesToMemoryRead
From the perspective of analysis, scf::ForOp is treated as a black box. Basic block arguments do not alias with their respective OpOperands on the ForOp, so they do not participate in conflict analysis with ops defined outside of the loop. However, bufferizesToMemoryRead and bufferizesToMemoryWrite on the scf::ForOp itself are used to determine how the scf::ForOp interacts with its surrounding ops. Differential Revision: https://reviews.llvm.org/D111775
1 parent d3cb6bf commit 7dd7078

File tree

2 files changed

+128
-9
lines changed

2 files changed

+128
-9
lines changed

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,31 @@ static OpResult getAliasingOpResult(OpOperand &opOperand) {
612612
[&](Operation *op) { return getInplaceableOpResult(opOperand); });
613613
}
614614

615+
// Predeclaration of function.
616+
static bool bufferizesToMemoryRead(OpOperand &opOperand);
617+
618+
/// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
619+
/// matching bbArg may.
620+
static bool bufferizesToMemoryRead(scf::ForOp forOp, OpOperand &opOperand) {
621+
SmallVector<OpOperand *> workingSet;
622+
for (OpOperand &use : forOp.getRegionIterArgForOpOperand(opOperand).getUses())
623+
workingSet.push_back(&use);
624+
625+
while (!workingSet.empty()) {
626+
OpOperand *uMaybeReading = workingSet.pop_back_val();
627+
// Skip over all ExtractSliceOps. These do not read by themselves but just
628+
// add a new alias.
629+
if (auto extractSliceOp =
630+
dyn_cast<ExtractSliceOp>(uMaybeReading->getOwner()))
631+
for (OpOperand &use : extractSliceOp.result().getUses())
632+
workingSet.push_back(&use);
633+
if (bufferizesToMemoryRead(*uMaybeReading))
634+
return true;
635+
}
636+
637+
return false;
638+
}
639+
615640
/// Return true if `opOperand` bufferizes to a memory read.
616641
static bool bufferizesToMemoryRead(OpOperand &opOperand) {
617642
// Unknown op that returns a tensor. The inplace analysis does not support
@@ -622,15 +647,8 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
622647
// may.
623648
if (isa<ExtractSliceOp>(opOperand.getOwner()))
624649
return false;
625-
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
626-
// matching bbArg may.
627-
if (auto forOp = dyn_cast<scf::ForOp>(opOperand.getOwner())) {
628-
for (OpOperand &use :
629-
forOp.getRegionIterArgForOpOperand(opOperand).getUses())
630-
if (bufferizesToMemoryRead(use))
631-
return true;
632-
return false;
633-
}
650+
if (auto forOp = dyn_cast<scf::ForOp>(opOperand.getOwner()))
651+
return bufferizesToMemoryRead(forOp, opOperand);
634652
// TiledLoop alone doesn't bufferize to a memory read, one of the uses of its
635653
// matching bbArg may.
636654
if (auto tiledLoopOp = dyn_cast<TiledLoopOp>(opOperand.getOwner())) {

mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,3 +912,104 @@ func @interleaved_extract_insert_slice_chain_2(
912912

913913
return %15 : tensor<62x90xf32>
914914
}
915+
916+
// -----
917+
918+
#accesses = [
919+
affine_map<(i) -> (i)>
920+
]
921+
#trait = {
922+
indexing_maps = #accesses,
923+
iterator_types = ["parallel"]
924+
}
925+
926+
// CHECK-LABEL: func @reading_scf_for
927+
func @reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
928+
%s: index, %v: vector<5xf32>) -> (tensor<?xf32>, vector<5xf32>) {
929+
930+
%c0 = arith.constant 0 : index
931+
%c1 = arith.constant 1 : index
932+
%cst = arith.constant 0.0 : f32
933+
934+
// Write to %t1.
935+
// CHECK: vector.transfer_write
936+
// CHECK-SAME: __inplace_results_attr__ = ["false"]
937+
%t3 = vector.transfer_write %v, %t1[%s] : vector<5xf32>, tensor<?xf32>
938+
939+
// Read the old value of %t1 inside the loop via an alias.
940+
// CHECK: scf.for
941+
%r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
942+
// CHECK: tensor.extract_slice
943+
// CHECK-SAME: __inplace_results_attr__ = ["true"]
944+
%e = tensor.extract_slice %t2[%s][%s][1] : tensor<?xf32> to tensor<?xf32>
945+
946+
// Read from %t1 via alias %e.
947+
%v2 = vector.transfer_read %e[%s], %cst : tensor<?xf32>, vector<5xf32>
948+
scf.yield %e, %v2 : tensor<?xf32>, vector<5xf32>
949+
}
950+
// CHECK: __inplace_results_attr__ = ["true", "none"]
951+
952+
// Use %t3 in some way without reading it, so that it does not get DCE'd.
953+
// CHECK: linalg.generic
954+
// CHECK-SAME: __inplace_results_attr__ = ["true"]
955+
%o = linalg.generic #trait outs (%t3 : tensor<?xf32>) {
956+
^bb(%0: f32) :
957+
linalg.yield %cst : f32
958+
} -> (tensor<?xf32>)
959+
960+
return %o, %v3 : tensor<?xf32>, vector<5xf32>
961+
}
962+
963+
// -----
964+
965+
#accesses = [
966+
affine_map<(i) -> (i)>
967+
]
968+
#trait = {
969+
indexing_maps = #accesses,
970+
iterator_types = ["parallel"]
971+
}
972+
973+
// CHECK-LABEL: func @non_reading_scf_for
974+
func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
975+
%s: index, %v: vector<5xf32>) -> (tensor<?xf32>, vector<5xf32>) {
976+
977+
%c0 = arith.constant 0 : index
978+
%c1 = arith.constant 1 : index
979+
%cst = arith.constant 0.0 : f32
980+
981+
// Write to %t1.
982+
// CHECK: vector.transfer_write
983+
// CHECK-SAME: __inplace_results_attr__ = ["true"]
984+
%t3 = vector.transfer_write %v, %t1[%s] : vector<5xf32>, tensor<?xf32>
985+
986+
// This loop does not read from %t1. It only writes to it.
987+
// CHECK: scf.for
988+
%r, %v3 = scf.for %i = %c0 to %s step %c1 iter_args(%t2 = %t1, %v0 = %v) -> (tensor<?xf32>, vector<5xf32>) {
989+
// CHECK: tensor.extract_slice
990+
// CHECK-SAME: __inplace_results_attr__ = ["true"]
991+
%e = tensor.extract_slice %t2[%s][%s][1] : tensor<?xf32> to tensor<?xf32>
992+
993+
// Write to %t1 via alias. (Overwrite %t3.)
994+
// CHECK: linalg.generic
995+
// CHECK-SAME: __inplace_results_attr__ = ["true"]
996+
%o2 = linalg.generic #trait outs (%e : tensor<?xf32>) {
997+
^bb(%0: f32) :
998+
linalg.yield %cst : f32
999+
} -> (tensor<?xf32>)
1000+
1001+
// Read overwritten value. This is not a read of %t1.
1002+
%v2 = vector.transfer_read %o2[%s], %cst : tensor<?xf32>, vector<5xf32>
1003+
scf.yield %o2, %v2 : tensor<?xf32>, vector<5xf32>
1004+
}
1005+
1006+
// Use %t3 in some way without reading it, so that it does not get DCE'd.
1007+
// CHECK: linalg.generic
1008+
// CHECK-SAME: __inplace_results_attr__ = ["true"]
1009+
%o = linalg.generic #trait outs (%t3 : tensor<?xf32>) {
1010+
^bb(%0: f32) :
1011+
linalg.yield %cst : f32
1012+
} -> (tensor<?xf32>)
1013+
1014+
return %o, %v3 : tensor<?xf32>, vector<5xf32>
1015+
}

0 commit comments

Comments
 (0)