Skip to content

Commit 7116ece

Browse files
committed
[mlir] Do not bufferize parallel_insert_slice dest to read for full slices
Signed-off-by: Max Dawkins <[email protected]>
1 parent b7bc1d0 commit 7116ece

File tree

2 files changed

+48
-27
lines changed

2 files changed

+48
-27
lines changed

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2020
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
2121
#include "mlir/Dialect/Utils/StaticValueUtils.h"
22+
#include "mlir/IR/BuiltinTypeInterfaces.h"
2223
#include "mlir/IR/Dialect.h"
2324
#include "mlir/IR/Operation.h"
2425

@@ -636,6 +637,34 @@ struct InsertOpInterface
636637
}
637638
};
638639

640+
template <typename InsertOpTy>
641+
static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
642+
OpOperand &opOperand) {
643+
RankedTensorType destType = insertSliceOp.getDestType();
644+
645+
// The source is always read.
646+
if (opOperand == insertSliceOp.getSourceMutable())
647+
return true;
648+
649+
// For the destination, it depends...
650+
assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
651+
652+
// Dest is not read if it is entirely overwritten. E.g.:
653+
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
654+
bool allOffsetsZero =
655+
llvm::all_of(insertSliceOp.getMixedOffsets(),
656+
[](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
657+
bool sizesMatchDestSizes = llvm::all_of(
658+
llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
659+
return getConstantIntValue(it.value()) ==
660+
destType.getDimSize(it.index());
661+
});
662+
bool allStridesOne =
663+
llvm::all_of(insertSliceOp.getMixedStrides(),
664+
[](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
665+
return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
666+
}
667+
639668
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
640669
/// certain circumstances, this op can also be a no-op.
641670
///
@@ -646,32 +675,8 @@ struct InsertSliceOpInterface
646675
tensor::InsertSliceOp> {
647676
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
648677
const AnalysisState &state) const {
649-
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
650-
RankedTensorType destType = insertSliceOp.getDestType();
651-
652-
// The source is always read.
653-
if (opOperand == insertSliceOp.getSourceMutable())
654-
return true;
655-
656-
// For the destination, it depends...
657-
assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
658-
659-
// Dest is not read if it is entirely overwritten. E.g.:
660-
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
661-
bool allOffsetsZero =
662-
llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
663-
return isConstantIntValue(ofr, 0);
664-
});
665-
bool sizesMatchDestSizes = llvm::all_of(
666-
llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
667-
return getConstantIntValue(it.value()) ==
668-
destType.getDimSize(it.index());
669-
});
670-
bool allStridesOne =
671-
llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
672-
return isConstantIntValue(ofr, 1);
673-
});
674-
return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
678+
return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
679+
opOperand);
675680
}
676681

677682
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -931,7 +936,8 @@ struct ParallelInsertSliceOpInterface
931936

932937
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
933938
const AnalysisState &state) const {
934-
return true;
939+
return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
940+
opOperand);
935941
}
936942

937943
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,

mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,21 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso
213213

214214
// -----
215215

216+
// Full-slice parallel_insert_slice: the shared_outs dest is entirely
// overwritten, so no out-of-place bufferization (alloc + copy) is needed.
// CHECK-LABEL: func.func @parallel_insert_full_slice_in_place
// CHECK-NOT: memref.alloc()
func.func @parallel_insert_full_slice_in_place(%2: tensor<2xf32>) -> tensor<2xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %3 = scf.forall (%arg0) in (1) shared_outs(%arg2 = %2) -> (tensor<2xf32>) {
    %fill = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2xf32>) -> tensor<2xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %fill into %arg2[0] [2] [1] : tensor<2xf32> into tensor<2xf32>
    }
  } {mapping = [#gpu.thread<linear_dim_0>]}
  return %3 : tensor<2xf32>
}
228+
229+
// -----
230+
216231
// This test case could bufferize in-place with a better analysis. However, it
217232
// is simpler to let the canonicalizer fold away the tensor.insert_slice.
218233

0 commit comments

Comments
 (0)