Skip to content

Commit 98e838a

Browse files
authored
[mlir] Do not bufferize parallel_insert_slice dest to read for full slices (#112761)
In the insert_slice bufferization interface implementation, the destination tensor is not considered read if the full tensor is overwritten by the slice. This PR adds the same check for tensor.parallel_insert_slice. Adds two new StaticValueUtils: - `isAllConstantIntValue` checks if an array of `OpFoldResult` are all equal to a passed `int64_t` value. - `areConstantIntValues` checks if an array of `OpFoldResult` are all equal to a passed array of `int64_t` values. fixes #112435 --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent f148d57 commit 98e838a

File tree

5 files changed

+62
-33
lines changed

5 files changed

+62
-33
lines changed

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs);
9292

9393
/// Return true if `ofr` is constant integer equal to `value`.
9494
bool isConstantIntValue(OpFoldResult ofr, int64_t value);
95+
/// Return true if all of `ofrs` are constant integers equal to `value`.
96+
bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value);
97+
/// Return true if all of `ofrs` are constant integers equal to the
98+
/// corresponding value in `values`.
99+
bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
100+
ArrayRef<int64_t> values);
95101

96102
/// Return true if ofr1 and ofr2 are the same integer constant attribute
97103
/// values or the same SSA value. Ignore integer bitwitdh and type mismatch

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2020
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
2121
#include "mlir/Dialect/Utils/StaticValueUtils.h"
22+
#include "mlir/IR/BuiltinTypeInterfaces.h"
2223
#include "mlir/IR/Dialect.h"
2324
#include "mlir/IR/Operation.h"
2425

@@ -636,6 +637,28 @@ struct InsertOpInterface
636637
}
637638
};
638639

640+
template <typename InsertOpTy>
641+
static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
642+
OpOperand &opOperand) {
643+
// The source is always read.
644+
if (opOperand == insertSliceOp.getSourceMutable())
645+
return true;
646+
647+
// For the destination, it depends...
648+
assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
649+
650+
// Dest is not read if it is entirely overwritten. E.g.:
651+
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
652+
bool allOffsetsZero =
653+
llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
654+
RankedTensorType destType = insertSliceOp.getDestType();
655+
bool sizesMatchDestSizes =
656+
areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());
657+
bool allStridesOne =
658+
areAllConstantIntValue(insertSliceOp.getMixedStrides(), 1);
659+
return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
660+
}
661+
639662
/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
640663
/// certain circumstances, this op can also be a no-op.
641664
///
@@ -646,32 +669,8 @@ struct InsertSliceOpInterface
646669
tensor::InsertSliceOp> {
647670
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
648671
const AnalysisState &state) const {
649-
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
650-
RankedTensorType destType = insertSliceOp.getDestType();
651-
652-
// The source is always read.
653-
if (opOperand == insertSliceOp.getSourceMutable())
654-
return true;
655-
656-
// For the destination, it depends...
657-
assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
658-
659-
// Dest is not read if it is entirely overwritten. E.g.:
660-
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
661-
bool allOffsetsZero =
662-
llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
663-
return isConstantIntValue(ofr, 0);
664-
});
665-
bool sizesMatchDestSizes = llvm::all_of(
666-
llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
667-
return getConstantIntValue(it.value()) ==
668-
destType.getDimSize(it.index());
669-
});
670-
bool allStridesOne =
671-
llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
672-
return isConstantIntValue(ofr, 1);
673-
});
674-
return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
672+
return insertSliceOpRequiresRead(cast<tensor::InsertSliceOp>(op),
673+
opOperand);
675674
}
676675

677676
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -931,7 +930,8 @@ struct ParallelInsertSliceOpInterface
931930

932931
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
933932
const AnalysisState &state) const {
934-
return true;
933+
return insertSliceOpRequiresRead(cast<tensor::ParallelInsertSliceOp>(op),
934+
opOperand);
935935
}
936936

937937
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,

mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@ namespace mlir {
1616
namespace tensor {
1717
namespace {
1818

19-
static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
20-
return llvm::all_of(
21-
ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
22-
}
23-
2419
/// Returns the number of shape sizes that is either dynamic or greater than 1.
2520
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
2621
return llvm::count_if(

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/IR/Matchers.h"
1111
#include "mlir/Support/LLVM.h"
1212
#include "llvm/ADT/APSInt.h"
13+
#include "llvm/ADT/STLExtras.h"
1314
#include "llvm/Support/MathExtras.h"
1415

1516
namespace mlir {
@@ -131,12 +132,24 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
131132
return res;
132133
}
133134

134-
/// Return true if `ofr` is constant integer equal to `value`.
135135
bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
136136
auto val = getConstantIntValue(ofr);
137137
return val && *val == value;
138138
}
139139

140+
bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
141+
return llvm::all_of(
142+
ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
143+
}
144+
145+
bool areConstantIntValues(ArrayRef<OpFoldResult> ofrs,
146+
ArrayRef<int64_t> values) {
147+
if (ofrs.size() != values.size())
148+
return false;
149+
std::optional<SmallVector<int64_t>> constOfrs = getConstantIntValues(ofrs);
150+
return constOfrs && llvm::equal(constOfrs.value(), values);
151+
}
152+
140153
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
141154
/// or the same SSA value.
142155
/// Ignore integer bitwidth and type mismatch that come from the fact there is

mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,21 @@ func.func @rank_reducing_parallel_insert_slice(%in: tensor<100xf32>, %out: tenso
213213

214214
// -----
215215

216+
// CHECK-LABEL: func.func @parallel_insert_full_slice_in_place
217+
// CHECK-NOT: memref.alloc()
218+
func.func @parallel_insert_full_slice_in_place(%2: tensor<2xf32>) -> tensor<2xf32> {
219+
%cst = arith.constant 0.000000e+00 : f32
220+
%3 = scf.forall (%arg0) in (1) shared_outs(%arg2 = %2) -> (tensor<2xf32>) {
221+
%fill = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<2xf32>) -> tensor<2xf32>
222+
scf.forall.in_parallel {
223+
tensor.parallel_insert_slice %fill into %arg2[0] [2] [1] : tensor<2xf32> into tensor<2xf32>
224+
}
225+
} {mapping = [#gpu.thread<linear_dim_0>]}
226+
return %3 : tensor<2xf32>
227+
}
228+
229+
// -----
230+
216231
// This test case could bufferize in-place with a better analysis. However, it
217232
// is simpler to let the canonicalizer fold away the tensor.insert_slice.
218233

0 commit comments

Comments
 (0)