Skip to content

Commit a1bc979

Browse files
[mlir][Bufferization] Do not have read semantics for destination of tensor.parallel_insert_slice. (#134169)
`tensor.insert_slice` needs to have read semantics on its destination operand. Since it has a return value, its semantics are: copy dest to result, then copy source to a subview of the destination. `tensor.parallel_insert_slice`, though, has no result, so it does not need to have read semantics. The op description [here](https://github.com/llvm/llvm-project/blob/a3ac318e5f8668ec5b79dd86639881dfb2e88b69/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td#L1524) also says that it is expected to lower to a `memref.subview`, which does not have read semantics on the destination (it's just a view). This patch drops the read semantics for the destination of `tensor.parallel_insert_slice` but also makes the `shared_outs` operands of `scf.forall` have read semantics. Earlier it would rely indirectly on the read semantics of the destination operand of `tensor.parallel_insert_slice` to propagate the read semantics for `shared_outs`. Now that is specified more directly. Fixes #133964 --------- Signed-off-by: MaheshRavishankar <[email protected]>
1 parent bc6cd82 commit a1bc979

File tree

3 files changed

+41
-25
lines changed

3 files changed

+41
-25
lines changed

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,18 +1186,6 @@ struct YieldOpInterface
11861186
}
11871187
};
11881188

1189-
/// Return `true` if the given loop may have 0 iterations.
1190-
bool mayHaveZeroIterations(scf::ForallOp forallOp) {
1191-
for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
1192-
forallOp.getMixedUpperBound())) {
1193-
std::optional<int64_t> lbConst = getConstantIntValue(lb);
1194-
std::optional<int64_t> ubConst = getConstantIntValue(ub);
1195-
if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
1196-
return true;
1197-
}
1198-
return false;
1199-
}
1200-
12011189
/// Bufferization of ForallOp. This also bufferizes the terminator of the
12021190
/// region. There are op interfaces for the terminators (InParallelOp
12031191
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
@@ -1207,17 +1195,11 @@ struct ForallOpInterface
12071195
ForallOp> {
12081196
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
12091197
const AnalysisState &state) const {
1210-
auto forallOp = cast<ForallOp>(op);
1211-
1212-
// If the loop has zero iterations, the results of the op are their
1213-
// corresponding shared_outs, meaning that the shared_outs bufferize to a
1214-
// read.
1215-
if (mayHaveZeroIterations(forallOp))
1216-
return true;
1217-
1218-
// scf::ForallOp alone doesn't bufferize to a memory read, one of the
1219-
// uses of its matching bbArg may.
1220-
return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
1198+
// All tensor operands to `scf.forall` are `shared_outs` and all
1199+
// shared outs are assumed to be read by the loop. This does not
1200+
// account for the case where the entire value is over-written,
1201+
// but being conservative here.
1202+
return true;
12211203
}
12221204

12231205
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -930,8 +930,7 @@ struct ParallelInsertSliceOpInterface
930930

931931
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
932932
const AnalysisState &state) const {
933-
return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
934-
opOperand);
933+
return opOperand == cast<ParallelInsertSliceOp>(op).getSourceMutable();
935934
}
936935

937936
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,

mlir/test/Dialect/SCF/one-shot-bufferize.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -946,3 +946,38 @@ func.func @index_switch(%pred: index, %b: tensor<5xf32>, %c: tensor<5xf32>) -> t
946946
// CHECK: return %[[r]]
947947
return %0 : tensor<5xf32>
948948
}
949+
950+
// -----
951+
952+
// See Issue https://github.com/llvm/llvm-project/issues/133964. Checks that
953+
// tensor.parallel_insert_slice dest operand does not have read semantics.
954+
func.func @check_scfforall_inplace_bufferizer(%arg0 : tensor<?x?xf32>,
955+
%arg1 : tensor<?x?xf32>,
956+
%arg2 : tensor<?xf32> {bufferization.writable = true}) -> tensor<?xf32> {
957+
%c0 = arith.constant 0 : index
958+
%c1 = arith.constant 1 : index
959+
%d0 = tensor.dim %arg2, %c0 : tensor<?xf32>
960+
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
961+
%0 = scf.forall (%arg3) in (%c1) shared_outs(%arg4 = %arg2) -> (tensor<?xf32>) {
962+
%1 = tensor.extract_slice %arg0[0, 0][%d0, %d1][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
963+
%2 = tensor.extract_slice %arg1[0, 0][%d0, %d1][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
964+
%3 = linalg.generic {
965+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
966+
affine_map<(d0, d1) -> (d0, d1)>,
967+
affine_map<(d0, d1) -> (d0)>],
968+
iterator_types = ["parallel", "reduction"]}
969+
ins(%1, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
970+
outs(%arg4 : tensor<?xf32>) {
971+
^bb0(%b0 : f32, %b1: f32, %b2 : f32):
972+
%4 = arith.mulf %b0, %b1 : f32
973+
%5 = arith.addf %4, %b2 : f32
974+
linalg.yield %5 : f32
975+
} -> tensor<?xf32>
976+
scf.forall.in_parallel {
977+
tensor.parallel_insert_slice %3 into %arg4[0] [%d0] [1] : tensor<?xf32> into tensor<?xf32>
978+
}
979+
}
980+
return %0 : tensor<?xf32>
981+
}
982+
// CHECK-LABEL: func @check_scfforall_inplace_bufferizer
983+
// CHECK-NOT: memref.alloc

0 commit comments

Comments
 (0)